Skip to content

Commit

Permalink
Refactor token handling to use same HanaToken model that is JWT-encoded
Browse files Browse the repository at this point in the history
  • Loading branch information
NeonDaniel committed Nov 12, 2024
1 parent 542db86 commit 7c2945f
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 31 deletions.
5 changes: 3 additions & 2 deletions neon_hana/app/routers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
@user_route.post("/get")
async def get_user(request: GetUserRequest,
token: str = Depends(jwt_bearer)) -> User:
user_id = jwt_bearer.client_manager.get_token_user_id(token)
return mq_connector.read_user(access_token=token, auth_user=user_id,
hana_token = jwt_bearer.client_manager.get_token_data(token)
return mq_connector.read_user(access_token=hana_token,
auth_user=hana_token.sub,
**dict(request))


Expand Down
52 changes: 27 additions & 25 deletions neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

from neon_data_models.models.api.jwt import HanaToken
from neon_hana.mq_service_api import MQServiceManager
from neon_data_models.models.user import (User, TokenConfig, NeonUserConfig,
from neon_data_models.models.user import (User, NeonUserConfig,
PermissionsConfig)
from neon_data_models.enum import AccessRoles
from neon_hana.schema.auth_requests import AuthenticationResponse
Expand Down Expand Up @@ -94,7 +94,7 @@ def _create_tokens(self,
client_id: str,
token_name: Optional[str] = None,
permissions: Optional[PermissionsConfig] = None,
**kwargs) -> (str, str, TokenConfig):
**kwargs) -> (str, str, Dict[str, HanaToken]):
token_id = str(uuid4())
# Subtract a second from creation so the token may be used immediately
# upon return
Expand All @@ -106,14 +106,17 @@ def _create_tokens(self,
node=AccessRoles.GUEST,
llm=AccessRoles.GUEST)
token_name = token_name or kwargs.get("name") or \
datetime.fromtimestamp(creation_timestamp).isoformat()
datetime.fromtimestamp(creation_timestamp).isoformat()
access_token_data = HanaToken(iss=self._jwt_issuer,
sub=user_id,
exp=expiration_timestamp,
iat=creation_timestamp,
jti=token_id,
client_id=client_id,
roles=permissions.to_roles(),
token_name=token_name,
creation_timestamp=creation_timestamp,
last_refresh_timestamp=creation_timestamp,
purpose="access")
refresh_token_data = HanaToken(iss=self._jwt_issuer,
sub=user_id,
Expand All @@ -122,20 +125,17 @@ def _create_tokens(self,
jti=f"{token_id}.refresh",
client_id=client_id,
roles=permissions.to_roles(),
token_name=token_name,
creation_timestamp=creation_timestamp,
last_refresh_timestamp=creation_timestamp,
purpose="refresh")
access_token = jwt.encode(access_token_data.model_dump(),
self._access_secret, self._jwt_algo)
refresh_token = jwt.encode(refresh_token_data.model_dump(),
self._refresh_secret, self._jwt_algo)
token_config = TokenConfig(token_name=token_name,
token_id=token_id,
user_id=user_id,
client_id=client_id,
permissions=permissions,
refresh_expiration_timestamp=refresh_expiration_timestamp,
creation_timestamp=creation_timestamp,
last_refresh_timestamp=creation_timestamp)
return access_token, refresh_token, token_config

return access_token, refresh_token, {"access": access_token_data,
"refresh": refresh_token_data}

def check_connect_stream(self) -> bool:
"""
Expand Down Expand Up @@ -243,9 +243,9 @@ def check_auth_request(self, client_id: str, username: str,
client_id=client_id,
access_token=access,
refresh_token=refresh,
expiration=config.refresh_expiration_timestamp)
expiration=config['access'].exp)
self.authorized_clients[client_id] = auth_response
self._add_token_to_userdb(user, config)
self._add_token_to_userdb(user, config['refresh'])
return auth_response

def check_refresh_request(self, access_token: Optional[str],
Expand All @@ -256,9 +256,10 @@ def check_refresh_request(self, access_token: Optional[str],
refresh_data = HanaToken(**jwt.decode(refresh_token,
self._refresh_secret,
self._jwt_algo))
# token_data = HanaToken(**jwt.decode(access_token,
# self._access_secret,
# self._jwt_algo))
token_data = HanaToken(**jwt.decode(access_token,
self._access_secret,
self._jwt_algo,
options={"verify_signature": False}))
if refresh_data.purpose != "refresh":
raise HTTPException(status_code=400,
detail="Supplied refresh token not valid")
Expand Down Expand Up @@ -290,7 +291,7 @@ def check_refresh_request(self, access_token: Optional[str],
}
if self._mq_connector:
user = self._mq_connector.read_user(username=refresh_data.sub,
access_token=refresh_token)
access_token=token_data)
if not user.password_hash:
# This should not be possible, but don't let an error in the
# users service allow for injecting a new valid token to the db
Expand All @@ -306,18 +307,21 @@ def check_refresh_request(self, access_token: Optional[str],
client_id=client_id,
access_token=access,
refresh_token=refresh,
expiration=config.refresh_expiration_timestamp)
expiration=config['access'].refresh_expiration_timestamp)
self._authorized_clients[client_id] = auth_response
return auth_response

def _add_token_to_userdb(self, user: User, new_token: TokenConfig):
def _add_token_to_userdb(self, user: User, new_token: HanaToken):
if new_token.purpose != "refresh":
raise ValueError(f"Expected a refresh token, got: "
f"{new_token.purpose}")
if self._mq_connector is None:
print("No MQ Connection to a user database")
return
for idx, token in enumerate(user.tokens):
# If the token is already defined, maintain the original
# token_id and creation timestamp
if token.token_id == new_token.token_id:
if token.jti == new_token.jti:
new_token.token_name = token.token_name
new_token.creation_timestamp = token.creation_timestamp
user.tokens.remove(token)
Expand All @@ -334,16 +338,14 @@ def get_client_id(self, token: str) -> str:
self._jwt_algo))
return auth.client_id

def get_token_user_id(self, token: str) -> str:
def get_token_data(self, token: str) -> HanaToken:
"""
Extract the user_id from a JWT string
@param token: JWT to parse
@retrun: user_id associated with token
"""
auth = HanaToken(**jwt.decode(token, self._access_secret,
return HanaToken(**jwt.decode(token, self._access_secret,
self._jwt_algo))
return auth.user_id


def validate_auth(self, token: str, origin_ip: str) -> bool:
ratelimit_id = f"{origin_ip}-total"
Expand Down
8 changes: 6 additions & 2 deletions neon_hana/mq_service_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from neon_data_models.models.api import CreateUserRequest, ReadUserRequest, \
UpdateUserRequest, DeleteUserRequest
from neon_data_models.models.api.jwt import HanaToken
from neon_mq_connector.utils.client_utils import send_mq_request
from neon_data_models.models.client.node import NodeData
from neon_data_models.models.user.neon_profile import UserProfile
Expand Down Expand Up @@ -127,7 +128,7 @@ def create_user(self, user: User) -> User:
return err_or_user

def read_user(self, username: str, password: Optional[str] = None,
access_token: Optional[str] = None,
access_token: Optional[HanaToken] = None,
auth_user: Optional[str] = None) -> User:
"""
Get a User object for a user. This requires that a valid password OR
Expand All @@ -136,9 +137,10 @@ def read_user(self, username: str, password: Optional[str] = None,
@param username: Valid username to get a User object for
@param password: Valid password to use for authentication
@param access_token: Valid access token to use for authentication
@param auth_user: Optional username to use for authentication
@param auth_user: Optional username or user ID to use for authentication
@returns: User object from the Users service.
"""
auth_user = auth_user or username
read_user_request = ReadUserRequest(user_spec=username,
auth_user_spec=auth_user,
access_token=access_token,
Expand All @@ -158,6 +160,8 @@ def update_user(self, user: User,
@param auth_password: Password associated with `auth_user`
@returns: User as read from the database
"""
auth_user = auth_user or user.username
auth_password = auth_password or user.password_hash
update_user_request = UpdateUserRequest(user=user,
auth_username=auth_user,
auth_password=auth_password,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def test_check_refresh_request(self):
user_id=str(uuid4()), client_id=valid_client)
access2, refresh2, config2 = self.client_manager._create_tokens(
user_id=str(uuid4()), client_id=str(uuid4()))
self.assertEqual(config.client_id, valid_client)
self.assertEqual(config['access'].client_id, valid_client)
self.assertEqual(config['refresh'].client_id, valid_client)

# Test invalid refresh token
with self.assertRaises(HTTPException) as e:
Expand Down Expand Up @@ -139,7 +140,7 @@ def test_check_refresh_request(self):
user_id=str(uuid4()), client_id=valid_client)
with self.assertRaises(HTTPException) as e:
self.client_manager.check_refresh_request(access, refresh,
config.client_id)
config['access'].client_id)
self.assertEqual(e.exception.status_code, 401)
self.client_manager._refresh_token_lifetime = real_refresh

Expand Down

0 comments on commit 7c2945f

Please sign in to comment.