diff --git a/README.md b/README.md index d741025..c378839 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ hana: auth_requests_per_minute: 6 # This counts valid and invalid requests from an IP address access_token_secret: a800445648142061fc238d1f84e96200da87f4f9fa7835cac90db8b4391b117b refresh_token_secret: 833d369ac73d883123743a44b4a7fe21203cffc956f4c8fec712e71aafa8e1aa + jwt_issuer: neon.ai # Used in the `iss` field of generated JWT tokens. fastapi_title: "My HANA API Host" fastapi_summary: "Personal HTTP API to access my DIANA backend." disable_auth: True diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index e26a51b..9dcf22e 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -23,10 +23,12 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from threading import Lock +from uuid import uuid4 import jwt +from datetime import datetime +from threading import Lock from time import time from typing import Dict, Optional from fastapi import Request, HTTPException @@ -36,11 +38,12 @@ from token_throttler import TokenThrottler, TokenBucket from token_throttler.storage import RuntimeStorage -from neon_hana.auth.permissions import ClientPermissions +from neon_data_models.models.api.jwt import HanaToken from neon_hana.mq_service_api import MQServiceManager from neon_data_models.models.user import (User, TokenConfig, NeonUserConfig, PermissionsConfig) from neon_data_models.enum import AccessRoles +from neon_hana.schema.auth_requests import AuthenticationResponse _DEFAULT_USER_PERMISSIONS = PermissionsConfig(klat=AccessRoles.USER, core=AccessRoles.USER, @@ -55,10 +58,14 @@ def __init__(self, config: dict, mq_connector: Optional[MQServiceManager] = None): self.rate_limiter = TokenThrottler(cost=1, storage=RuntimeStorage()) - self.authorized_clients: Dict[str, dict] = dict() + # TODO: Is `authorized_clients` useful to track? + # Keep a dict of `client_id` to auth tokens that have authenticated to + # this instance + self.authorized_clients: Dict[str, HanaToken] = dict() self._access_token_lifetime = config.get("access_token_ttl", 3600 * 24) self._refresh_token_lifetime = config.get("refresh_token_ttl", - 3600 * 24 * 7) + 3600 * 24 * 90) + self._jwt_issuer = config.get("jwt_issuer", "neon.ai") self._access_secret = config.get("access_token_secret") self._refresh_secret = config.get("refresh_token_secret") self._rpm = config.get("requests_per_minute", 60) @@ -72,42 +79,48 @@ def __init__(self, config: dict, self._stream_check_lock = Lock() self._mq_connector = mq_connector - def _create_tokens(self, encode_data: dict) -> TokenConfig: - # Permissions were not included in old tokens, allow refreshing with - # default permissions - encode_data.setdefault("permissions", ClientPermissions().as_dict()) + def _create_tokens(self, + user_id: str, + client_id: str, + token_name: Optional[str] = None, + permissions: Optional[PermissionsConfig] = None, + **kwargs) -> (HanaToken, HanaToken, TokenConfig): + token_id = str(uuid4()) + creation_timestamp = round(time()) + expiration_timestamp = creation_timestamp + self._access_token_lifetime + refresh_expiration_timestamp = creation_timestamp + self._refresh_token_lifetime + permissions = permissions or PermissionsConfig(core=AccessRoles.GUEST, + diana=AccessRoles.GUEST, + node=AccessRoles.GUEST, + llm=AccessRoles.GUEST) + token_name = token_name or kwargs.get("name") or \ + datetime.fromtimestamp(creation_timestamp).isoformat() + access_token_data = HanaToken(iss=self._jwt_issuer, + sub=user_id, + exp=expiration_timestamp, + iat=creation_timestamp, + jti=token_id, + client_id=client_id, + roles=permissions.to_roles(), + purpose="access") + refresh_token_data = HanaToken(iss=self._jwt_issuer, + sub=user_id, + exp=refresh_expiration_timestamp, + iat=creation_timestamp, + jti=f"{token_id}.refresh", + client_id=client_id, + roles=permissions.to_roles(), + purpose="refresh") - token_expiration = encode_data['expire'] - token = jwt.encode(encode_data, self._access_secret, self._jwt_algo) - encode_data['expire'] = round(time()) + self._refresh_token_lifetime - encode_data['access_token'] = token - refresh = jwt.encode(encode_data, self._refresh_secret, self._jwt_algo) - return TokenConfig(**{"username": encode_data['username'], - "client_id": encode_data['client_id'], - "permissions": encode_data['permissions'], - "access_token": token, - "refresh_token": refresh, - "expiration": token_expiration, - "refresh_expiration": encode_data['expire'], - "token_name": encode_data['name'], - "creation_timestamp": encode_data['create'], - "last_refresh_timestamp": encode_data['last_refresh_timestamp'] - }) - - def get_permissions(self, client_id: str) -> ClientPermissions: - """ - Get ClientPermissions model for the given client_id - @param client_id: Client ID to get permissions for - @return: ClientPermissions object for the specified client - """ - if self._disable_auth: - LOG.debug("Auth disabled, allow full client permissions") - return ClientPermissions(assist=True, backend=True, node=True) - if client_id not in self.authorized_clients: - LOG.warning(f"{client_id} not known to this server") - return ClientPermissions(assist=False, backend=False, node=False) - client = self.authorized_clients[client_id] - return ClientPermissions(**client.get('permissions', dict())) + token_config = TokenConfig(token_name=token_name, + token_id=token_id, + user_id=user_id, + client_id=client_id, + permissions=permissions, + refresh_expiration_timestamp=refresh_expiration_timestamp, + creation_timestamp=creation_timestamp, + last_refresh_timestamp=creation_timestamp) + return access_token_data, refresh_token_data, token_config def check_connect_stream(self) -> bool: """ @@ -145,7 +158,7 @@ def check_registration_request(self, username: str, password: str, def check_auth_request(self, client_id: str, username: str, password: Optional[str] = None, token_name: Optional[str] = None, - origin_ip: str = "127.0.0.1") -> dict: + origin_ip: str = "127.0.0.1") -> AuthenticationResponse: """ Authenticate and Authorize a new client connection with the specified username, password, and origin IP address. @@ -156,9 +169,9 @@ def check_auth_request(self, client_id: str, username: str, @param origin_ip: Origin IP address of request @return: response tokens, permissions, and other metadata """ - if client_id in self.authorized_clients: - print(f"Using cached client: {self.authorized_clients[client_id]}") - return self.authorized_clients[client_id] + # if client_id in self.authorized_clients: + # print(f"Using cached client: {self.authorized_clients[client_id]}") + # return self.authorized_clients[client_id] ratelimit_id = f"auth{origin_ip}" if not self.rate_limiter.get_all_buckets(ratelimit_id): @@ -182,91 +195,93 @@ def check_auth_request(self, client_id: str, username: str, user.permissions.node = AccessRoles.USER else: user = self._mq_connector.get_user_profile(username, password) - username = user.username - # Boolean permissions allow access for any role, including `NODE`. - # Specific endpoints may enforce more granular controls/limits based on - # specific user.permissions values. - permissions = ClientPermissions( - node=user.permissions.node != AccessRoles.NONE, - assist=user.permissions.core != AccessRoles.NONE, - backend=user.permissions.diana != AccessRoles.NONE) create_time = round(time()) - expiration = create_time + self._access_token_lifetime encode_data = {"client_id": client_id, - "sub": username, # Added for Klat token compat. - "name": token_name, - "username": username, - "permissions": permissions.as_dict(), - "create": create_time, - "expire": expiration, + "user_id": user.user_id, + "permissions": user.permissions, + "token_name": token_name, "last_refresh_timestamp": create_time} - auth = self._create_tokens(encode_data) - self._add_token_to_userdb(user, auth) - self.authorized_clients[client_id] = auth.model_dump() - return auth.model_dump() + access, refresh, config = self._create_tokens(**encode_data) + self.authorized_clients[client_id] = config + self._add_token_to_userdb(user, config) + return AuthenticationResponse(username=user.username, + client_id=client_id, + access_token=access, + refresh_token=refresh, + expiration=config.refresh_expiration_timestamp) def check_refresh_request(self, access_token: str, refresh_token: str, - client_id: str): + client_id: str) -> AuthenticationResponse: # Read and validate refresh token try: - refresh_data = jwt.decode(refresh_token, self._refresh_secret, - self._jwt_algo) + refresh_data = HanaToken(**jwt.decode(refresh_token, + self._refresh_secret, + self._jwt_algo)) + token_data = HanaToken(**jwt.decode(access_token, + self._access_secret, + self._jwt_algo)) except DecodeError: raise HTTPException(status_code=400, detail="Invalid refresh token supplied") - if refresh_data['access_token'] != access_token: + if refresh_data.jti != token_data.jti + ".refresh": raise HTTPException(status_code=403, detail="Refresh and access token mismatch") - if time() > refresh_data['expire']: + if time() > refresh_data.exp: raise HTTPException(status_code=401, detail="Refresh token is expired") - # Read access token and re-generate a new pair of tokens - # This is already known to be a valid token based on the refresh token - token_data = jwt.decode(access_token, self._access_secret, - self._jwt_algo) - if token_data['client_id'] != client_id: + if token_data.client_id != client_id: raise HTTPException(status_code=403, detail="Access token does not match client_id") - encode_data = {k: token_data[k] for k in - ("client_id", "username", "password")} + # `token_name` is not known here, but it will be read from the database + # when the new token replaces the old one + encode_data = {"user_id": token_data.sub, + "client_id": client_id, + "permissions": PermissionsConfig.from_roles(token_data.roles) + } if self._mq_connector: - user = self._mq_connector.get_user_profile(username=token_data['username'], + user = self._mq_connector.get_user_profile(username=token_data.sub, access_token=refresh_token) if not user.password_hash: # This should not be possible, but don't let an error in the # users service allow for injecting a new valid token to the db raise HTTPException(status_code=500, detail="Error Fetching User") - refresh_time = round(time()) - encode_data['last_refresh_timestamp'] = refresh_time - encode_data["expire"] = refresh_time + self._access_token_lifetime - new_auth = self._create_tokens(encode_data) - self._add_token_to_userdb(user, new_auth) + access, refresh, config = self._create_tokens(**encode_data) + username = user.username + self._add_token_to_userdb(user, config) else: - new_auth = self._create_tokens(encode_data) - return new_auth.model_dump() + username = token_data.sub + access, refresh, config = self._create_tokens(**encode_data) + return AuthenticationResponse(username=username, + client_id=client_id, + access_token=access, + refresh_token=refresh, + expiration=config.refresh_expiration_timestamp) - def _add_token_to_userdb(self, user: User, token_data: TokenConfig): + def _add_token_to_userdb(self, user: User, new_token: TokenConfig): if self._mq_connector is None: print("No MQ Connection to a user database") return - # Enforce unique `creation_timestamp` values to avoid duplicate entries for idx, token in enumerate(user.tokens): - if token.creation_timestamp == token_data.creation_timestamp: + if token.token_id == new_token.token_id: + # Tokens don't contain `token_name`, so use the same one as is + # being replaced + new_token.token_name = token.token_name user.tokens.remove(token) - user.tokens.append(token_data) + user.tokens.append(new_token) self._mq_connector.update_user(user) def get_client_id(self, token: str) -> str: """ - Extract the client_id from a JWT token - @param token: JWT token to parse + Extract the client_id from a JWT string + @param token: JWT to parse @return: client_id associated with token """ - auth = jwt.decode(token, self._access_secret, self._jwt_algo) - return auth['client_id'] + auth = HanaToken(**jwt.decode(token, self._access_secret, + self._jwt_algo)) + return auth.client_id def validate_auth(self, token: str, origin_ip: str) -> bool: if not self.rate_limiter.get_all_buckets(origin_ip): @@ -281,11 +296,12 @@ def validate_auth(self, token: str, origin_ip: str) -> bool: if self._disable_auth: return True try: - auth = jwt.decode(token, self._access_secret, self._jwt_algo) - if auth['expire'] < time(): - self.authorized_clients.pop(auth['client_id'], None) + auth = HanaToken(**jwt.decode(token, self._access_secret, + self._jwt_algo)) + if auth.exp < time(): + self.authorized_clients.pop(auth.client_id, None) return False - self.authorized_clients[auth['client_id']] = auth + self.authorized_clients[auth.client_id] = auth return True except DecodeError: # Invalid token supplied diff --git a/neon_hana/auth/permissions.py b/neon_hana/auth/permissions.py deleted file mode 100644 index 287cc38..0000000 --- a/neon_hana/auth/permissions.py +++ /dev/null @@ -1,43 +0,0 @@ -# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System -# All trademark and other rights reserved by their respective owners -# Copyright 2008-2021 Neongecko.com Inc. -# BSD-3 -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# 1. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from this -# software without specific prior written permission. -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, -# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from dataclasses import dataclass, asdict - - -@dataclass -class ClientPermissions: - """ - Data class representing permissions of a particular client connection. - """ - assist: bool = True - backend: bool = True - node: bool = False - - def as_dict(self) -> dict: - """ - Get a dict representation of this instance. - """ - return asdict(self) diff --git a/tests/test_auth.py b/tests/test_auth.py index 88422fb..73b0ae8 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -76,7 +76,7 @@ def test_validate_auth(self): "127.0.0.1")) self.assertFalse(self.client_manager.validate_auth(invalid_client, "127.0.0.1")) - + # TODO: Update token data expired_token = self.client_manager._create_tokens( {"client_id": invalid_client, "username": "test", "password": "test", "expire": time(), @@ -93,6 +93,7 @@ def test_validate_auth(self): def test_check_refresh_request(self): valid_client = str(uuid4()) + # TODO: Update token data tokens = self.client_manager._create_tokens({"client_id": valid_client, "username": "test", "password": "test", @@ -134,6 +135,7 @@ def test_check_refresh_request(self): # Test expired refresh token real_refresh = self.client_manager._refresh_token_lifetime self.client_manager._refresh_token_lifetime = 0 + # TODO: Update token data tokens = self.client_manager._create_tokens({"client_id": valid_client, "username": "test", "password": "test",