Skip to content

Commit

Permalink
Add optional audience param to validation functions (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
guyp-descope authored Aug 27, 2023
1 parent c385103 commit b3899b5
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 32 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ jwt_response = descope_client.validate_and_refresh_session(session_token, refres

Choose the right session validation and refresh combination that suits your needs.

Note: all those validation apis can receive an optional 'audience' parameter that should be provided when using jwt that has the 'aud' claim)

Refreshed sessions return the same response as is returned when users first sign up / log in,
containing the session and refresh tokens, as well as all of the JWT claims.
Make sure to return the tokens from the response to the client, or updated the cookie if you're using it.
Expand Down
65 changes: 48 additions & 17 deletions descope/auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
import sys

if sys.version_info[0] >= 3 and sys.version_info[1] >= 10:
# Python 3.10 and above
from collections.abc import Iterable
else:
from collections.abc import Iterable

import copy
import json
import os
Expand Down Expand Up @@ -153,7 +161,7 @@ def exchange_token(self, uri, code: str) -> dict:
response = self.do_post(uri=uri, body=body, params=None)
resp = response.json()
jwt_response = self.generate_jwt_response(
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME)
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME), None
)
return jwt_response

Expand Down Expand Up @@ -275,7 +283,7 @@ def exchange_access_key(self, access_key: str) -> dict:
server_response = self.do_post(uri=uri, body={}, params=None, pswd=access_key)
json = server_response.json()
return self._generate_auth_info(
response_body=json, refresh_token=None, user_jwt=False
response_body=json, refresh_token=None, user_jwt=False, audience=None
)

@staticmethod
Expand Down Expand Up @@ -421,19 +429,25 @@ def adjust_properties(self, jwt_response: dict, user_jwt: bool):
return jwt_response

def _generate_auth_info(
self, response_body: dict, refresh_token: str, user_jwt: bool
self,
response_body: dict,
refresh_token: str,
user_jwt: bool,
audience: str | Iterable[str] | None = None,
) -> dict:
jwt_response = {}
st_jwt = response_body.get("sessionJwt", "")
if st_jwt:
jwt_response[SESSION_TOKEN_NAME] = self._validate_token(st_jwt)
jwt_response[SESSION_TOKEN_NAME] = self._validate_token(st_jwt, audience)
rt_jwt = response_body.get("refreshJwt", "")
if refresh_token:
jwt_response[REFRESH_SESSION_TOKEN_NAME] = self._validate_token(
refresh_token
refresh_token, audience
)
elif rt_jwt:
jwt_response[REFRESH_SESSION_TOKEN_NAME] = self._validate_token(rt_jwt)
jwt_response[REFRESH_SESSION_TOKEN_NAME] = self._validate_token(
rt_jwt, audience
)

jwt_response = self.adjust_properties(jwt_response, user_jwt)

Expand All @@ -447,8 +461,15 @@ def _generate_auth_info(

return jwt_response

def generate_jwt_response(self, response_body: dict, refresh_cookie: str) -> dict:
jwt_response = self._generate_auth_info(response_body, refresh_cookie, True)
def generate_jwt_response(
self,
response_body: dict,
refresh_cookie: str,
audience: str | Iterable[str] | None = None,
) -> dict:
jwt_response = self._generate_auth_info(
response_body, refresh_cookie, True, audience
)

jwt_response["user"] = response_body.get("user", {})
jwt_response["firstSeen"] = response_body.get("firstSeen", True)
Expand All @@ -471,7 +492,9 @@ def _get_default_headers(self, pswd: str = None):
return headers

# Validate a token and load the public key if needed
def _validate_token(self, token: str) -> dict:
def _validate_token(
self, token: str, audience: str | Iterable[str] | None = None
) -> dict:
if not token:
raise AuthException(
500,
Expand Down Expand Up @@ -527,6 +550,7 @@ def _validate_token(self, token: str) -> dict:
jwt=token,
key=copy_key[0].key,
algorithms=[alg_header],
audience=audience,
leeway=self.jwt_validation_leeway,
)
except ImmatureSignatureError:
Expand All @@ -539,7 +563,9 @@ def _validate_token(self, token: str) -> dict:
claims["jwt"] = token
return claims

def validate_session(self, session_token: str) -> dict:
def validate_session(
self, session_token: str, audience: str | Iterable[str] | None = None
) -> dict:
if not session_token:
raise AuthException(
400,
Expand All @@ -548,7 +574,7 @@ def validate_session(self, session_token: str) -> dict:
)

try:
res = self._validate_token(session_token)
res = self._validate_token(session_token, audience)
res[SESSION_TOKEN_NAME] = copy.deepcopy(
res
) # Duplicate for saving backward compatibility but keep the same structure as the refresh operation response
Expand All @@ -560,7 +586,9 @@ def validate_session(self, session_token: str) -> dict:
401, ERROR_TYPE_INVALID_TOKEN, f"Invalid session token: {e}"
)

def refresh_session(self, refresh_token: str) -> dict:
def refresh_session(
self, refresh_token: str, audience: str | Iterable[str] | None = None
) -> dict:
if not refresh_token:
raise AuthException(
400,
Expand All @@ -569,7 +597,7 @@ def refresh_session(self, refresh_token: str) -> dict:
)

try:
self._validate_token(refresh_token)
self._validate_token(refresh_token, audience)
except RateLimitException as e:
raise e
except Exception as e:
Expand All @@ -582,10 +610,13 @@ def refresh_session(self, refresh_token: str) -> dict:
response = self.do_post(uri=uri, body={}, params=None, pswd=refresh_token)

resp = response.json()
return self.generate_jwt_response(resp, refresh_token)
return self.generate_jwt_response(resp, refresh_token, audience)

def validate_and_refresh_session(
self, session_token: str = None, refresh_token: str = None
self,
session_token: str = None,
refresh_token: str = None,
audience: str | Iterable[str] | None = None,
) -> dict:
if not session_token and not refresh_token:
raise AuthException(
Expand All @@ -595,10 +626,10 @@ def validate_and_refresh_session(
)

try:
return self.validate_session(session_token)
return self.validate_session(session_token, audience)
except Exception:
# Session is invalid - try to refresh it
return self.refresh_session(refresh_token)
return self.refresh_session(refresh_token, audience)

@staticmethod
def extract_masked_address(response: dict, method: DeliveryMethod) -> str:
Expand Down
2 changes: 1 addition & 1 deletion descope/authmethod/enchantedlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_session(self, pending_ref: str) -> dict:

resp = response.json()
jwt_response = self._auth.generate_jwt_response(
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None)
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None), None
)
return jwt_response

Expand Down
2 changes: 1 addition & 1 deletion descope/authmethod/magiclink.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def verify(self, token: str) -> dict:
response = self._auth.do_post(uri, body, None)
resp = response.json()
jwt_response = self._auth.generate_jwt_response(
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None)
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None), None
)
return jwt_response

Expand Down
2 changes: 1 addition & 1 deletion descope/authmethod/otp.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def verify_code(self, method: DeliveryMethod, login_id: str, code: str) -> dict:

resp = response.json()
jwt_response = self._auth.generate_jwt_response(
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None)
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None), None
)
return jwt_response

Expand Down
6 changes: 3 additions & 3 deletions descope/authmethod/password.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def sign_up(self, login_id: str, password: str, user: dict = None) -> dict:

resp = response.json()
jwt_response = self._auth.generate_jwt_response(
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None)
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None), None
)
return jwt_response

Expand Down Expand Up @@ -80,7 +80,7 @@ def sign_in(

resp = response.json()
jwt_response = self._auth.generate_jwt_response(
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None)
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None), None
)
return jwt_response

Expand Down Expand Up @@ -202,7 +202,7 @@ def replace(self, login_id: str, old_password: str, new_password: str) -> dict:

resp = response.json()
jwt_response = self._auth.generate_jwt_response(
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None)
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None), None
)
return jwt_response

Expand Down
2 changes: 1 addition & 1 deletion descope/authmethod/totp.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def sign_in_code(

resp = response.json()
jwt_response = self._auth.generate_jwt_response(
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None)
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None), None
)
return jwt_response

Expand Down
4 changes: 2 additions & 2 deletions descope/authmethod/webauthn.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def sign_up_finish(self, transaction_id: str, response: str) -> dict:

resp = response.json()
jwt_response = self._auth.generate_jwt_response(
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None)
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None), None
)
return jwt_response

Expand Down Expand Up @@ -104,7 +104,7 @@ def sign_in_finish(self, transaction_id: str, response: str) -> dict:

resp = response.json()
jwt_response = self._auth.generate_jwt_response(
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None)
resp, response.cookies.get(REFRESH_SESSION_COOKIE_NAME, None), None
)
return jwt_response

Expand Down
32 changes: 26 additions & 6 deletions descope/descope_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
import sys

if sys.version_info[0] >= 3 and sys.version_info[1] >= 10:
# Python 3.10 and above
from collections.abc import Iterable
else:
from collections.abc import Iterable

from typing import List

import requests
Expand Down Expand Up @@ -183,7 +191,9 @@ def validate_tenant_roles(
return False
return True

def validate_session(self, session_token: str) -> dict:
def validate_session(
self, session_token: str, audience: str | Iterable[str] | None = None
) -> dict:
"""
Validate a session token. Call this function for every incoming request to your
private endpoints. Alternatively, use validate_and_refresh_session in order to
Expand All @@ -194,32 +204,39 @@ def validate_session(self, session_token: str) -> dict:
Args:
session_token (str): The session token to be validated
audience (str|Iterable[str]|None): Optional recipients that the JWT is intended for (must be equal to the 'aud' claim on the provided token)
Return value (dict):
Return dict includes the session token and all JWT claims
Raise:
AuthException: Exception is raised if session is not authorized or any other error occurs
"""
return self._auth.validate_session(session_token)
return self._auth.validate_session(session_token, audience)

def refresh_session(self, refresh_token: str) -> dict:
def refresh_session(
self, refresh_token: str, audience: str | Iterable[str] | None = None
) -> dict:
"""
Refresh a session. Call this function when a session expires and needs to be refreshed.
Args:
refresh_token (str): The refresh token that will be used to refresh the session
audience (str|Iterable[str]|None): Optional recipients that the JWT is intended for (must be equal to the 'aud' claim on the provided token)
Return value (dict):
Return dict includes the session token, refresh token, and all JWT claims
Raise:
AuthException: Exception is raised if refresh token is not authorized or any other error occurs
"""
return self._auth.refresh_session(refresh_token)
return self._auth.refresh_session(refresh_token, audience)

def validate_and_refresh_session(
self, session_token: str, refresh_token: str
self,
session_token: str,
refresh_token: str,
audience: str | Iterable[str] | None = None,
) -> dict:
"""
Validate the session token and refresh it if it has expired, the session token will automatically be refreshed.
Expand All @@ -230,14 +247,17 @@ def validate_and_refresh_session(
Args:
session_token (str): The session token to be validated
refresh_token (str): The refresh token that will be used to refresh the session token, if needed
audience (str|Iterable[str]|None): Optional recipients that the JWT is intended for (must be equal to the 'aud' claim on the provided token)
Return value (dict):
Return dict includes the session token, refresh token, and all JWT claims
Raise:
AuthException: Exception is raised if session is not authorized or another error occurs
"""
return self._auth.validate_and_refresh_session(session_token, refresh_token)
return self._auth.validate_and_refresh_session(
session_token, refresh_token, audience
)

def logout(self, refresh_token: str) -> requests.Response:
"""
Expand Down

0 comments on commit b3899b5

Please sign in to comment.