Skip to content

Commit

Permalink
Add jwt_issuer config
Browse files Browse the repository at this point in the history
Deprecate `ClientPermissions` which duplicates role-based permissions spec in neon-data-models
Refactor token handling to use JWT model and updated configuration spec
  • Loading branch information
NeonDaniel committed Nov 5, 2024
1 parent e577118 commit e5b3f7d
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 138 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ hana:
auth_requests_per_minute: 6 # This counts valid and invalid requests from an IP address
access_token_secret: a800445648142061fc238d1f84e96200da87f4f9fa7835cac90db8b4391b117b
refresh_token_secret: 833d369ac73d883123743a44b4a7fe21203cffc956f4c8fec712e71aafa8e1aa
jwt_issuer: neon.ai # Used in the `iss` field of generated JWT tokens.
fastapi_title: "My HANA API Host"
fastapi_summary: "Personal HTTP API to access my DIANA backend."
disable_auth: True
Expand Down
204 changes: 110 additions & 94 deletions neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from threading import Lock
from uuid import uuid4

import jwt

from datetime import datetime
from threading import Lock
from time import time
from typing import Dict, Optional
from fastapi import Request, HTTPException
Expand All @@ -36,11 +38,12 @@
from token_throttler import TokenThrottler, TokenBucket
from token_throttler.storage import RuntimeStorage

from neon_hana.auth.permissions import ClientPermissions
from neon_data_models.models.api.jwt import HanaToken
from neon_hana.mq_service_api import MQServiceManager
from neon_data_models.models.user import (User, TokenConfig, NeonUserConfig,
PermissionsConfig)
from neon_data_models.enum import AccessRoles
from neon_hana.schema.auth_requests import AuthenticationResponse

_DEFAULT_USER_PERMISSIONS = PermissionsConfig(klat=AccessRoles.USER,
core=AccessRoles.USER,
Expand All @@ -55,10 +58,14 @@ def __init__(self, config: dict,
mq_connector: Optional[MQServiceManager] = None):
self.rate_limiter = TokenThrottler(cost=1, storage=RuntimeStorage())

self.authorized_clients: Dict[str, dict] = dict()
# TODO: Is `authorized_clients` useful to track?
# Keep a dict of `client_id` to auth tokens that have authenticated to
# this instance
self.authorized_clients: Dict[str, HanaToken] = dict()
self._access_token_lifetime = config.get("access_token_ttl", 3600 * 24)
self._refresh_token_lifetime = config.get("refresh_token_ttl",
3600 * 24 * 7)
3600 * 24 * 90)
self._jwt_issuer = config.get("jwt_issuer", "neon.ai")
self._access_secret = config.get("access_token_secret")
self._refresh_secret = config.get("refresh_token_secret")
self._rpm = config.get("requests_per_minute", 60)
Expand All @@ -72,42 +79,48 @@ def __init__(self, config: dict,
self._stream_check_lock = Lock()
self._mq_connector = mq_connector

def _create_tokens(self, encode_data: dict) -> TokenConfig:
# Permissions were not included in old tokens, allow refreshing with
# default permissions
encode_data.setdefault("permissions", ClientPermissions().as_dict())
def _create_tokens(self,
user_id: str,
client_id: str,
token_name: Optional[str] = None,
permissions: Optional[PermissionsConfig] = None,
**kwargs) -> (HanaToken, HanaToken, TokenConfig):
token_id = str(uuid4())
creation_timestamp = round(time())
expiration_timestamp = creation_timestamp + self._access_token_lifetime
refresh_expiration_timestamp = creation_timestamp + self._refresh_token_lifetime
permissions = permissions or PermissionsConfig(core=AccessRoles.GUEST,
diana=AccessRoles.GUEST,
node=AccessRoles.GUEST,
llm=AccessRoles.GUEST)
token_name = token_name or kwargs.get("name") or \
datetime.fromtimestamp(creation_timestamp).isoformat()
access_token_data = HanaToken(iss=self._jwt_issuer,
sub=user_id,
exp=expiration_timestamp,
iat=creation_timestamp,
jti=token_id,
client_id=client_id,
roles=permissions.to_roles(),
purpose="access")
refresh_token_data = HanaToken(iss=self._jwt_issuer,
sub=user_id,
exp=refresh_expiration_timestamp,
iat=creation_timestamp,
jti=f"{token_id}.refresh",
client_id=client_id,
roles=permissions.to_roles(),
purpose="refresh")

token_expiration = encode_data['expire']
token = jwt.encode(encode_data, self._access_secret, self._jwt_algo)
encode_data['expire'] = round(time()) + self._refresh_token_lifetime
encode_data['access_token'] = token
refresh = jwt.encode(encode_data, self._refresh_secret, self._jwt_algo)
return TokenConfig(**{"username": encode_data['username'],
"client_id": encode_data['client_id'],
"permissions": encode_data['permissions'],
"access_token": token,
"refresh_token": refresh,
"expiration": token_expiration,
"refresh_expiration": encode_data['expire'],
"token_name": encode_data['name'],
"creation_timestamp": encode_data['create'],
"last_refresh_timestamp": encode_data['last_refresh_timestamp']
})

def get_permissions(self, client_id: str) -> ClientPermissions:
"""
Get ClientPermissions model for the given client_id
@param client_id: Client ID to get permissions for
@return: ClientPermissions object for the specified client
"""
if self._disable_auth:
LOG.debug("Auth disabled, allow full client permissions")
return ClientPermissions(assist=True, backend=True, node=True)
if client_id not in self.authorized_clients:
LOG.warning(f"{client_id} not known to this server")
return ClientPermissions(assist=False, backend=False, node=False)
client = self.authorized_clients[client_id]
return ClientPermissions(**client.get('permissions', dict()))
token_config = TokenConfig(token_name=token_name,
token_id=token_id,
user_id=user_id,
client_id=client_id,
permissions=permissions,
refresh_expiration_timestamp=refresh_expiration_timestamp,
creation_timestamp=creation_timestamp,
last_refresh_timestamp=creation_timestamp)
return access_token_data, refresh_token_data, token_config

def check_connect_stream(self) -> bool:
"""
Expand Down Expand Up @@ -145,7 +158,7 @@ def check_registration_request(self, username: str, password: str,
def check_auth_request(self, client_id: str, username: str,
password: Optional[str] = None,
token_name: Optional[str] = None,
origin_ip: str = "127.0.0.1") -> dict:
origin_ip: str = "127.0.0.1") -> AuthenticationResponse:
"""
Authenticate and Authorize a new client connection with the specified
username, password, and origin IP address.
Expand All @@ -156,9 +169,9 @@ def check_auth_request(self, client_id: str, username: str,
@param origin_ip: Origin IP address of request
@return: response tokens, permissions, and other metadata
"""
if client_id in self.authorized_clients:
print(f"Using cached client: {self.authorized_clients[client_id]}")
return self.authorized_clients[client_id]
# if client_id in self.authorized_clients:
# print(f"Using cached client: {self.authorized_clients[client_id]}")
# return self.authorized_clients[client_id]

ratelimit_id = f"auth{origin_ip}"
if not self.rate_limiter.get_all_buckets(ratelimit_id):
Expand All @@ -182,91 +195,93 @@ def check_auth_request(self, client_id: str, username: str,
user.permissions.node = AccessRoles.USER
else:
user = self._mq_connector.get_user_profile(username, password)
username = user.username

# Boolean permissions allow access for any role, including `NODE`.
# Specific endpoints may enforce more granular controls/limits based on
# specific user.permissions values.
permissions = ClientPermissions(
node=user.permissions.node != AccessRoles.NONE,
assist=user.permissions.core != AccessRoles.NONE,
backend=user.permissions.diana != AccessRoles.NONE)
create_time = round(time())
expiration = create_time + self._access_token_lifetime
encode_data = {"client_id": client_id,
"sub": username, # Added for Klat token compat.
"name": token_name,
"username": username,
"permissions": permissions.as_dict(),
"create": create_time,
"expire": expiration,
"user_id": user.user_id,
"permissions": user.permissions,
"token_name": token_name,
"last_refresh_timestamp": create_time}
auth = self._create_tokens(encode_data)
self._add_token_to_userdb(user, auth)
self.authorized_clients[client_id] = auth.model_dump()
return auth.model_dump()
access, refresh, config = self._create_tokens(**encode_data)
self.authorized_clients[client_id] = config
self._add_token_to_userdb(user, config)
return AuthenticationResponse(username=user.username,
client_id=client_id,
access_token=access,
refresh_token=refresh,
expiration=config.refresh_expiration_timestamp)

def check_refresh_request(self, access_token: str, refresh_token: str,
client_id: str):
client_id: str) -> AuthenticationResponse:
# Read and validate refresh token
try:
refresh_data = jwt.decode(refresh_token, self._refresh_secret,
self._jwt_algo)
refresh_data = HanaToken(**jwt.decode(refresh_token,
self._refresh_secret,
self._jwt_algo))
token_data = HanaToken(**jwt.decode(access_token,
self._access_secret,
self._jwt_algo))
except DecodeError:
raise HTTPException(status_code=400,
detail="Invalid refresh token supplied")
if refresh_data['access_token'] != access_token:
if refresh_data.jti != token_data.jti + ".refresh":
raise HTTPException(status_code=403,
detail="Refresh and access token mismatch")
if time() > refresh_data['expire']:
if time() > refresh_data.exp:
raise HTTPException(status_code=401,
detail="Refresh token is expired")
# Read access token and re-generate a new pair of tokens
# This is already known to be a valid token based on the refresh token
token_data = jwt.decode(access_token, self._access_secret,
self._jwt_algo)

if token_data['client_id'] != client_id:
if token_data.client_id != client_id:
raise HTTPException(status_code=403,
detail="Access token does not match client_id")
encode_data = {k: token_data[k] for k in
("client_id", "username", "password")}

# `token_name` is not known here, but it will be read from the database
# when the new token replaces the old one
encode_data = {"user_id": token_data.sub,
"client_id": client_id,
"permissions": PermissionsConfig.from_roles(token_data.roles)
}
if self._mq_connector:
user = self._mq_connector.get_user_profile(username=token_data['username'],
user = self._mq_connector.get_user_profile(username=token_data.sub,
access_token=refresh_token)
if not user.password_hash:
# This should not be possible, but don't let an error in the
# users service allow for injecting a new valid token to the db
raise HTTPException(status_code=500, detail="Error Fetching User")
refresh_time = round(time())
encode_data['last_refresh_timestamp'] = refresh_time
encode_data["expire"] = refresh_time + self._access_token_lifetime
new_auth = self._create_tokens(encode_data)
self._add_token_to_userdb(user, new_auth)
access, refresh, config = self._create_tokens(**encode_data)
username = user.username
self._add_token_to_userdb(user, config)
else:
new_auth = self._create_tokens(encode_data)
return new_auth.model_dump()
username = token_data.sub
access, refresh, config = self._create_tokens(**encode_data)
return AuthenticationResponse(username=username,
client_id=client_id,
access_token=access,
refresh_token=refresh,
expiration=config.refresh_expiration_timestamp)

def _add_token_to_userdb(self, user: User, token_data: TokenConfig):
def _add_token_to_userdb(self, user: User, new_token: TokenConfig):
if self._mq_connector is None:
print("No MQ Connection to a user database")
return
# Enforce unique `creation_timestamp` values to avoid duplicate entries
for idx, token in enumerate(user.tokens):
if token.creation_timestamp == token_data.creation_timestamp:
if token.token_id == new_token.token_id:
# Tokens don't contain `token_name`, so use the same one as is
# being replaced
new_token.token_name = token.token_name
user.tokens.remove(token)
user.tokens.append(token_data)
user.tokens.append(new_token)
self._mq_connector.update_user(user)

def get_client_id(self, token: str) -> str:
"""
Extract the client_id from a JWT token
@param token: JWT token to parse
Extract the client_id from a JWT string
@param token: JWT to parse
@return: client_id associated with token
"""
auth = jwt.decode(token, self._access_secret, self._jwt_algo)
return auth['client_id']
auth = HanaToken(**jwt.decode(token, self._access_secret,
self._jwt_algo))
return auth.client_id

def validate_auth(self, token: str, origin_ip: str) -> bool:
if not self.rate_limiter.get_all_buckets(origin_ip):
Expand All @@ -281,11 +296,12 @@ def validate_auth(self, token: str, origin_ip: str) -> bool:
if self._disable_auth:
return True
try:
auth = jwt.decode(token, self._access_secret, self._jwt_algo)
if auth['expire'] < time():
self.authorized_clients.pop(auth['client_id'], None)
auth = HanaToken(**jwt.decode(token, self._access_secret,
self._jwt_algo))
if auth.exp < time():
self.authorized_clients.pop(auth.client_id, None)
return False
self.authorized_clients[auth['client_id']] = auth
self.authorized_clients[auth.client_id] = auth
return True
except DecodeError:
# Invalid token supplied
Expand Down
43 changes: 0 additions & 43 deletions neon_hana/auth/permissions.py

This file was deleted.

Loading

0 comments on commit e5b3f7d

Please sign in to comment.