diff --git a/rest_framework_jwt/utils.py b/rest_framework_jwt/utils.py index c72197bc..8d470ac6 100644 --- a/rest_framework_jwt/utils.py +++ b/rest_framework_jwt/utils.py @@ -12,7 +12,7 @@ from rest_framework_jwt.settings import api_settings -def jwt_get_secret_key(payload=None): +def jwt_get_secret_key(user_id=None): """ For enhanced security you may want to use a secret key based on user. @@ -21,12 +21,13 @@ def jwt_get_secret_key(payload=None): - password is changed - etc. """ - if api_settings.JWT_GET_USER_SECRET_KEY: - User = get_user_model() # noqa: N806 - user = User.objects.get(pk=payload.get('user_id')) - key = str(api_settings.JWT_GET_USER_SECRET_KEY(user)) - return key - return api_settings.JWT_SECRET_KEY + if not user_id: + return api_settings.JWT_SECRET_KEY + + User = get_user_model() # noqa: N806 + user = User.objects.get(pk=user_id) + key = str(api_settings.JWT_GET_USER_SECRET_KEY(user)) + return key def jwt_payload_handler(user): @@ -88,7 +89,9 @@ def jwt_get_username_from_payload_handler(payload): def jwt_encode_handler(payload): - key = api_settings.JWT_PRIVATE_KEY or jwt_get_secret_key(payload) + key = api_settings.JWT_PRIVATE_KEY or jwt_get_secret_key( + payload.get('user_id') + if api_settings.JWT_GET_USER_SECRET_KEY else None) return jwt.encode( payload, key, @@ -100,12 +103,12 @@ def jwt_decode_handler(token): options = { 'verify_exp': api_settings.JWT_VERIFY_EXPIRATION, } - # get user from token, BEFORE verification, to get user secret key - unverified_payload = jwt.decode(token, None, False) - secret_key = jwt_get_secret_key(unverified_payload) + key = api_settings.JWT_PUBLIC_KEY or jwt_get_secret_key( + jwt.decode(token, None, False).get('user_id') + if api_settings.JWT_GET_USER_SECRET_KEY else None) return jwt.decode( token, - api_settings.JWT_PUBLIC_KEY or secret_key, + key, api_settings.JWT_VERIFY, options=options, leeway=api_settings.JWT_LEEWAY,