From 7c2945f3a8f67612eb28c9341f9e0a5c3c334225 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 11 Nov 2024 19:08:35 -0800 Subject: [PATCH] Refactor token handling to use same HanaToken model that is JWT-encoded --- neon_hana/app/routers/user.py | 5 +-- neon_hana/auth/client_manager.py | 52 +++++++++++++++++--------------- neon_hana/mq_service_api.py | 8 +++-- tests/test_auth.py | 5 +-- 4 files changed, 39 insertions(+), 31 deletions(-) diff --git a/neon_hana/app/routers/user.py b/neon_hana/app/routers/user.py index 5810c8a..059320b 100644 --- a/neon_hana/app/routers/user.py +++ b/neon_hana/app/routers/user.py @@ -35,8 +35,9 @@ @user_route.post("/get") async def get_user(request: GetUserRequest, token: str = Depends(jwt_bearer)) -> User: - user_id = jwt_bearer.client_manager.get_token_user_id(token) - return mq_connector.read_user(access_token=token, auth_user=user_id, + hana_token = jwt_bearer.client_manager.get_token_data(token) + return mq_connector.read_user(access_token=hana_token, + auth_user=hana_token.sub, **dict(request)) diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index 9bd30d8..9eb3150 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -40,7 +40,7 @@ from neon_data_models.models.api.jwt import HanaToken from neon_hana.mq_service_api import MQServiceManager -from neon_data_models.models.user import (User, TokenConfig, NeonUserConfig, +from neon_data_models.models.user import (User, NeonUserConfig, PermissionsConfig) from neon_data_models.enum import AccessRoles from neon_hana.schema.auth_requests import AuthenticationResponse @@ -94,7 +94,7 @@ def _create_tokens(self, client_id: str, token_name: Optional[str] = None, permissions: Optional[PermissionsConfig] = None, - **kwargs) -> (str, str, TokenConfig): + **kwargs) -> (str, str, Dict[str, HanaToken]): token_id = str(uuid4()) # Subtract a second from creation so the token may be used immediately # upon return @@ -106,7 +106,7 @@ def _create_tokens(self, node=AccessRoles.GUEST, llm=AccessRoles.GUEST) token_name = token_name or kwargs.get("name") or \ - datetime.fromtimestamp(creation_timestamp).isoformat() + datetime.fromtimestamp(creation_timestamp).isoformat() access_token_data = HanaToken(iss=self._jwt_issuer, sub=user_id, exp=expiration_timestamp, @@ -114,6 +114,9 @@ def _create_tokens(self, jti=token_id, client_id=client_id, roles=permissions.to_roles(), + token_name=token_name, + creation_timestamp=creation_timestamp, + last_refresh_timestamp=creation_timestamp, purpose="access") refresh_token_data = HanaToken(iss=self._jwt_issuer, sub=user_id, @@ -122,20 +125,17 @@ def _create_tokens(self, jti=f"{token_id}.refresh", client_id=client_id, roles=permissions.to_roles(), + token_name=token_name, + creation_timestamp=creation_timestamp, + last_refresh_timestamp=creation_timestamp, purpose="refresh") access_token = jwt.encode(access_token_data.model_dump(), self._access_secret, self._jwt_algo) refresh_token = jwt.encode(refresh_token_data.model_dump(), self._refresh_secret, self._jwt_algo) - token_config = TokenConfig(token_name=token_name, - token_id=token_id, - user_id=user_id, - client_id=client_id, - permissions=permissions, - refresh_expiration_timestamp=refresh_expiration_timestamp, - creation_timestamp=creation_timestamp, - last_refresh_timestamp=creation_timestamp) - return access_token, refresh_token, token_config + + return access_token, refresh_token, {"access": access_token_data, + "refresh": refresh_token_data} def check_connect_stream(self) -> bool: """ @@ -243,9 +243,9 @@ def check_auth_request(self, client_id: str, username: str, client_id=client_id, access_token=access, refresh_token=refresh, - expiration=config.refresh_expiration_timestamp) + expiration=config['access'].exp) self.authorized_clients[client_id] = auth_response - self._add_token_to_userdb(user, config) + self._add_token_to_userdb(user, config['refresh']) return auth_response def check_refresh_request(self, access_token: Optional[str], @@ -256,9 +256,10 @@ def check_refresh_request(self, access_token: Optional[str], refresh_data = HanaToken(**jwt.decode(refresh_token, self._refresh_secret, self._jwt_algo)) - # token_data = HanaToken(**jwt.decode(access_token, - # self._access_secret, - # self._jwt_algo)) + token_data = HanaToken(**jwt.decode(access_token, + self._access_secret, + self._jwt_algo, + options={"verify_signature": False})) if refresh_data.purpose != "refresh": raise HTTPException(status_code=400, detail="Supplied refresh token not valid") @@ -290,7 +291,7 @@ def check_refresh_request(self, access_token: Optional[str], } if self._mq_connector: user = self._mq_connector.read_user(username=refresh_data.sub, - access_token=refresh_token) + access_token=token_data) if not user.password_hash: # This should not be possible, but don't let an error in the # users service allow for injecting a new valid token to the db @@ -306,18 +307,21 @@ def check_refresh_request(self, access_token: Optional[str], client_id=client_id, access_token=access, refresh_token=refresh, - expiration=config.refresh_expiration_timestamp) + expiration=config['access'].refresh_expiration_timestamp) self._authorized_clients[client_id] = auth_response return auth_response - def _add_token_to_userdb(self, user: User, new_token: TokenConfig): + def _add_token_to_userdb(self, user: User, new_token: HanaToken): + if new_token.purpose != "refresh": + raise ValueError(f"Expected a refresh token, got: " + f"{new_token.purpose}") if self._mq_connector is None: print("No MQ Connection to a user database") return for idx, token in enumerate(user.tokens): # If the token is already defined, maintain the original # token_id and creation timestamp - if token.token_id == new_token.token_id: + if token.jti == new_token.jti: new_token.token_name = token.token_name new_token.creation_timestamp = token.creation_timestamp user.tokens.remove(token) @@ -334,16 +338,14 @@ def get_client_id(self, token: str) -> str: self._jwt_algo)) return auth.client_id - def get_token_user_id(self, token: str) -> str: + def get_token_data(self, token: str) -> HanaToken: """ Extract the user_id from a JWT string @param token: JWT to parse @retrun: user_id associated with token """ - auth = HanaToken(**jwt.decode(token, self._access_secret, + return HanaToken(**jwt.decode(token, self._access_secret, self._jwt_algo)) - return auth.user_id - def validate_auth(self, token: str, origin_ip: str) -> bool: ratelimit_id = f"{origin_ip}-total" diff --git a/neon_hana/mq_service_api.py b/neon_hana/mq_service_api.py index 4521489..de8b646 100644 --- a/neon_hana/mq_service_api.py +++ b/neon_hana/mq_service_api.py @@ -33,6 +33,7 @@ from neon_data_models.models.api import CreateUserRequest, ReadUserRequest, \ UpdateUserRequest, DeleteUserRequest +from neon_data_models.models.api.jwt import HanaToken from neon_mq_connector.utils.client_utils import send_mq_request from neon_data_models.models.client.node import NodeData from neon_data_models.models.user.neon_profile import UserProfile @@ -127,7 +128,7 @@ def create_user(self, user: User) -> User: return err_or_user def read_user(self, username: str, password: Optional[str] = None, - access_token: Optional[str] = None, + access_token: Optional[HanaToken] = None, auth_user: Optional[str] = None) -> User: """ Get a User object for a user. This requires that a valid password OR @@ -136,9 +137,10 @@ def read_user(self, username: str, password: Optional[str] = None, @param username: Valid username to get a User object for @param password: Valid password to use for authentication @param access_token: Valid access token to use for authentication - @param auth_user: Optional username to use for authentication + @param auth_user: Optional username or user ID to use for authentication @returns: User object from the Users service. """ + auth_user = auth_user or username read_user_request = ReadUserRequest(user_spec=username, auth_user_spec=auth_user, access_token=access_token, @@ -158,6 +160,8 @@ def update_user(self, user: User, @param auth_password: Password associated with `auth_user` @returns: User as read from the database """ + auth_user = auth_user or user.username + auth_password = auth_password or user.password_hash update_user_request = UpdateUserRequest(user=user, auth_username=auth_user, auth_password=auth_password, diff --git a/tests/test_auth.py b/tests/test_auth.py index 7e2c862..c6e5139 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -104,7 +104,8 @@ def test_check_refresh_request(self): user_id=str(uuid4()), client_id=valid_client) access2, refresh2, config2 = self.client_manager._create_tokens( user_id=str(uuid4()), client_id=str(uuid4())) - self.assertEqual(config.client_id, valid_client) + self.assertEqual(config['access'].client_id, valid_client) + self.assertEqual(config['refresh'].client_id, valid_client) # Test invalid refresh token with self.assertRaises(HTTPException) as e: @@ -139,7 +140,7 @@ def test_check_refresh_request(self): user_id=str(uuid4()), client_id=valid_client) with self.assertRaises(HTTPException) as e: self.client_manager.check_refresh_request(access, refresh, - config.client_id) + config['access'].client_id) self.assertEqual(e.exception.status_code, 401) self.client_manager._refresh_token_lifetime = real_refresh