diff --git a/flask_discord/client.py b/flask_discord/client.py index a0cfad3..37fa2a4 100644 --- a/flask_discord/client.py +++ b/flask_discord/client.py @@ -94,14 +94,12 @@ class DiscordOAuth2Session(_http.DiscordOAuth2HttpClient): # Encode any params into a jwt with the state as the key # Use generate_token in case state is None - session['DISCORD_RAW_OAUTH2_STATE'] = session.get("DISCORD_OAUTH2_STATE", generate_token()) - state = jwt.encode(params, session.get("DISCORD_RAW_OAUTH2_STATE")) + session['DISCORD_JWT_KEY'] = session.get("DISCORD_JWT_KEY", generate_token()) + state = jwt.encode(params, session.get("DISCORD_JWT_KEY")) discord_session = self._make_session(scope=scope, state=state) authorization_url, state = discord_session.authorization_url(configs.DISCORD_AUTHORIZATION_BASE_URL) - - # Save the encoded state as that's what Oauth2 lib is expecting - session["DISCORD_OAUTH2_STATE"] = state.decode("utf-8") + session['DISCORD_OAUTH2_STATE'] = state.decode("utf-8") # Add special parameters to uri instead of state uri_params = {'prompt': prompt} @@ -145,16 +143,24 @@ class DiscordOAuth2Session(_http.DiscordOAuth2HttpClient): It fetches the authorization token and saves it flask `session `_ object. + Raises + ------ + oauthlib.oauth2.rfc6749.errors.MismatchingStateError + jwt.exceptions.InvalidSignatureError + """ if request.values.get("error"): return request.values["error"] + + # Decode JWT. This only works if the state matches. + passed_state = request.args.get("state") + jwt_key = session.get("DISCORD_JWT_KEY") + decoded = jwt.decode(passed_state, jwt_key) + + # Now that we've decoded the state, we can continue the oauth2 process token = self._fetch_token() self.save_authorization_token(token) - - # Decode any parameters passed through state variable - raw_oauth_state = session.get("DISCORD_RAW_OAUTH2_STATE") - passed_state = request.args.get("state") - return jwt.decode(passed_state, raw_oauth_state) + return decoded def revoke(self): """This method clears current discord token, state and all session data from flask