Skip to content

Commit

Permalink
[Issue #3499] Make user token session typing cleaner (#3516)
Browse files Browse the repository at this point in the history
## Summary
Fixes #3516

### Time to review: __3 mins__

## Changes proposed
Add a utility method for fetching the user token session to handle type
checking issues.

## Context for reviewers
This removes the need to have `api_jwt_auth.current_user # type: ignore`
in every place we fetch the user.

Note that we have to do some sort of type ignore/cast because the type
is `None | Any` on `current_user` as it's just whatever you return from
the decode_token method.

## Additional information
All tests continue to pass which rely on fetching this
  • Loading branch information
chouinar authored Jan 14, 2025
1 parent d0f8b56 commit 158a768
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
14 changes: 7 additions & 7 deletions api/src/api/users/user_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def login_result() -> flask.Response:
def user_token_logout(db_session: db.Session) -> response.ApiResponse:
logger.info("POST /v1/users/token/logout")

user_token_session: UserTokenSession = api_jwt_auth.current_user # type: ignore
user_token_session: UserTokenSession = api_jwt_auth.get_user_token_session()
with db_session.begin():
user_token_session.is_valid = False
db_session.add(user_token_session)
Expand All @@ -123,7 +123,7 @@ def user_token_logout(db_session: db.Session) -> response.ApiResponse:
def user_token_refresh(db_session: db.Session) -> response.ApiResponse:
logger.info("POST /v1/users/token/refresh")

user_token_session: UserTokenSession = api_jwt_auth.current_user # type: ignore
user_token_session: UserTokenSession = api_jwt_auth.get_user_token_session()

with db_session.begin():
refresh_token_expiration(user_token_session)
Expand All @@ -148,7 +148,7 @@ def user_token_refresh(db_session: db.Session) -> response.ApiResponse:
def user_get(db_session: db.Session, user_id: UUID) -> response.ApiResponse:
logger.info("GET /v1/users/:user_id")

user_token_session: UserTokenSession = api_jwt_auth.current_user # type: ignore
user_token_session: UserTokenSession = api_jwt_auth.get_user_token_session()

if user_token_session.user_id == user_id:
with db_session.begin():
Expand All @@ -170,7 +170,7 @@ def user_save_opportunity(
) -> response.ApiResponse:
logger.info("POST /v1/users/:user_id/saved-opportunities")

user_token_session: UserTokenSession = api_jwt_auth.current_user # type: ignore
user_token_session: UserTokenSession = api_jwt_auth.get_user_token_session()

# Verify the authenticated user matches the requested user_id
if user_token_session.user_id != user_id:
Expand Down Expand Up @@ -205,7 +205,7 @@ def user_delete_saved_opportunity(
) -> response.ApiResponse:
logger.info("DELETE /v1/users/:user_id/saved-opportunities/:opportunity_id")

user_token_session: UserTokenSession = api_jwt_auth.current_user # type: ignore
user_token_session: UserTokenSession = api_jwt_auth.get_user_token_session()

# Verify the authenticated user matches the requested user_id
if user_token_session.user_id != user_id:
Expand All @@ -226,7 +226,7 @@ def user_delete_saved_opportunity(
def user_get_saved_opportunities(db_session: db.Session, user_id: UUID) -> response.ApiResponse:
logger.info("GET /v1/users/:user_id/saved-opportunities")

user_token_session: UserTokenSession = api_jwt_auth.current_user # type: ignore
user_token_session: UserTokenSession = api_jwt_auth.get_user_token_session()

# Verify the authenticated user matches the requested user_id
if user_token_session.user_id != user_id:
Expand All @@ -249,7 +249,7 @@ def user_save_search(
) -> response.ApiResponse:
logger.info("POST /v1/users/:user_id/saved-searches")

user_token_session: UserTokenSession = api_jwt_auth.current_user # type: ignore
user_token_session: UserTokenSession = api_jwt_auth.get_user_token_session()

# Verify the authenticated user matches the requested user_id
if user_token_session.user_id != user_id:
Expand Down
6 changes: 4 additions & 2 deletions api/src/auth/api_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Tuple

import jwt
from apiflask import HTTPTokenAuth
from pydantic import Field
from sqlalchemy import select
from sqlalchemy.orm import selectinload
Expand All @@ -14,13 +13,16 @@
from src.adapters.db import flask_db
from src.api.route_utils import raise_flask_error
from src.auth.auth_errors import JwtValidationError
from src.auth.jwt_user_http_token_auth import JwtUserHttpTokenAuth
from src.db.models.user_models import User, UserTokenSession
from src.logging.flask_logger import add_extra_data_to_current_request_logs
from src.util.env_config import PydanticBaseEnvConfig

logger = logging.getLogger(__name__)

api_jwt_auth = HTTPTokenAuth("ApiKey", header="X-SGG-Token", security_scheme_name="ApiJwtAuth")
api_jwt_auth = JwtUserHttpTokenAuth(
"ApiKey", header="X-SGG-Token", security_scheme_name="ApiJwtAuth"
)


class ApiJwtConfig(PydanticBaseEnvConfig):
Expand Down
16 changes: 16 additions & 0 deletions api/src/auth/jwt_user_http_token_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import cast

from apiflask import HTTPTokenAuth

from src.db.models.user_models import UserTokenSession


class JwtUserHttpTokenAuth(HTTPTokenAuth):

def get_user_token_session(self) -> UserTokenSession:
"""Wrapper method around the current_user value to handle type issues
Note that this value gets set based on whatever is returned from the method
you configure for @<your JwtUserHttpTokenAuth obj>.verify_token
"""
return cast(UserTokenSession, self.current_user)

0 comments on commit 158a768

Please sign in to comment.