diff --git a/api/api/middleware/response_headers_middleware.py b/api/api/middleware/response_headers_middleware.py index 76b21e4ae6a..740c4650378 100644 --- a/api/api/middleware/response_headers_middleware.py +++ b/api/api/middleware/response_headers_middleware.py @@ -1,4 +1,6 @@ -from api.utils.oauth2_helper import get_token_info +from rest_framework.request import Request + +from api.models.oauth import ThrottledApplication def response_headers_middleware(get_response): @@ -11,14 +13,15 @@ def response_headers_middleware(get_response): to identify malicious requesters or request patterns. """ - def middleware(request): + def middleware(request: Request): response = get_response(request) - if hasattr(request, "auth") and request.auth: - token_info = get_token_info(str(request.auth)) - if token_info: - response["x-ov-client-application-name"] = token_info.application_name - response["x-ov-client-application-verified"] = token_info.verified + if not (hasattr(request, "auth") and hasattr(request.auth, "application")): + return response + + application: ThrottledApplication = request.auth.application + response["x-ov-client-application-name"] = application.name + response["x-ov-client-application-verified"] = application.verified return response diff --git a/api/api/serializers/media_serializers.py b/api/api/serializers/media_serializers.py index 961ce5b5027..1dd62e8b054 100644 --- a/api/api/serializers/media_serializers.py +++ b/api/api/serializers/media_serializers.py @@ -54,7 +54,7 @@ class PaginatedRequestSerializer(serializers.Serializer): def validate_page_size(self, value): request = self.context.get("request") - is_anonymous = bool(request and request.user and request.user.is_anonymous) + is_anonymous = getattr(request, "auth", None) is None max_value = ( settings.MAX_ANONYMOUS_PAGE_SIZE if is_anonymous @@ -247,7 +247,7 @@ class MediaSearchRequestSerializer(PaginatedRequestSerializer): def is_request_anonymous(self): request = self.context.get("request") - return bool(request and request.user and request.user.is_anonymous) + return getattr(request, "auth", None) is None @staticmethod def _truncate(value): diff --git a/api/api/utils/oauth2_helper.py b/api/api/utils/oauth2_helper.py deleted file mode 100644 index 6ffd0c4c0c9..00000000000 --- a/api/api/utils/oauth2_helper.py +++ /dev/null @@ -1,65 +0,0 @@ -import datetime as dt -import logging -from dataclasses import dataclass - -from oauth2_provider.models import AccessToken - -from api import models - - -parent_logger = logging.getLogger(__name__) - - -@dataclass -class TokenInfo: - """Extracted ``models.ThrottledApplication`` metadata.""" - - client_id: str - rate_limit_model: str - verified: bool - application_name: str - - @property - def valid(self): - return self.client_id and self.verified - - -def get_token_info(token: str) -> None | TokenInfo: - """ - Recover an OAuth2 application client ID and rate limit model from an access token. - - :param token: An OAuth2 access token. - :return: If the token is valid, return the client ID associated with the - token, rate limit model, and email verification status as a tuple; else - return ``(None, None, None)``. - """ - logger = parent_logger.getChild("get_token_info") - try: - token = AccessToken.objects.get(token=token) - except AccessToken.DoesNotExist: - return None - - try: - application = models.ThrottledApplication.objects.get(accesstoken=token) - except models.ThrottledApplication.DoesNotExist: - # Critical because it indicates a data integrity problem. - # In practice should never occur so long as the preceding - # operation to retrieve the access token was successful. - logger.critical("Failed to find application associated with access token.") - return None - - expired = token.expires < dt.datetime.now(token.expires.tzinfo) - if expired: - logger.info( - "rejected expired access token " - f"application.name={application.name} " - f"application.client_id={application.client_id} " - ) - return None - - return TokenInfo( - client_id=str(application.client_id), - rate_limit_model=application.rate_limit_model, - verified=application.verified, - application_name=application.name, - ) diff --git a/api/api/utils/throttle.py b/api/api/utils/throttle.py index 4e3effaab34..8d4f4b1a859 100644 --- a/api/api/utils/throttle.py +++ b/api/api/utils/throttle.py @@ -5,8 +5,6 @@ from redis.exceptions import ConnectionError -from api.utils.oauth2_helper import get_token_info - parent_logger = logging.getLogger(__name__) @@ -47,8 +45,11 @@ def has_valid_token(self, request): if not request.auth: return False - token_info = get_token_info(str(request.auth)) - return token_info and token_info.valid + application = getattr(request.auth, "application", None) + if application is None: + return False + + return application.client_id and application.verified def get_cache_key(self, request, view): return self.cache_format % { @@ -146,15 +147,16 @@ class AbstractOAuth2IdRateThrottle(SimpleRateThrottle, metaclass=abc.ABCMeta): def get_cache_key(self, request, view): # Find the client ID associated with the access token. - auth = str(request.auth) - token_info = get_token_info(auth) - if not (token_info and token_info.valid): + if not self.has_valid_token(request): return None - if token_info.rate_limit_model not in self.applies_to_rate_limit_model: + # `self.has_valid_token` call earlier ensures accessing `application` will not fail + application = request.auth.application + + if application.rate_limit_model not in self.applies_to_rate_limit_model: return None - return self.cache_format % {"scope": self.scope, "ident": token_info.client_id} + return self.cache_format % {"scope": self.scope, "ident": application.client_id} class OAuth2IdThumbnailRateThrottle(AbstractOAuth2IdRateThrottle): diff --git a/api/api/views/oauth2_views.py b/api/api/views/oauth2_views.py index db4d780af72..b8511ae03ad 100644 --- a/api/api/views/oauth2_views.py +++ b/api/api/views/oauth2_views.py @@ -7,6 +7,7 @@ from django.conf import settings from django.core.cache import cache from django.core.mail import send_mail +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.views import APIView @@ -22,7 +23,6 @@ OAuth2KeyInfoSerializer, OAuth2RegistrationSerializer, ) -from api.utils.oauth2_helper import get_token_info from api.utils.throttle import OnePerSecond, TenPerDay @@ -169,7 +169,7 @@ class CheckRates(APIView): throttle_classes = (OnePerSecond,) @key_info - def get(self, request, format=None): + def get(self, request: Request, format=None): """ Get information about your API key. @@ -181,23 +181,17 @@ def get(self, request, format=None): """ # TODO: Replace 403 responses with DRF `authentication_classes`. - if not request.auth: + if not request.auth or not hasattr(request.auth, "application"): return Response(status=403, data="Forbidden") - access_token = str(request.auth) - token_info = get_token_info(access_token) + application: ThrottledApplication = request.auth.application - if not token_info: - # This shouldn't happen if `request.auth` was true above, - # but better safe than sorry - return Response(status=403, data="Forbidden") - - client_id = token_info.client_id + client_id = application.client_id if not client_id: return Response(status=403, data="Forbidden") - throttle_type = token_info.rate_limit_model + throttle_type = application.rate_limit_model throttle_key = "throttle_{scope}_{client_id}" if throttle_type == "standard": sustained_throttle_key = throttle_key.format( @@ -242,7 +236,7 @@ def get(self, request, format=None): "requests_this_minute": burst_requests, "requests_today": sustained_requests, "rate_limit_model": throttle_type, - "verified": token_info.verified, + "verified": application.verified, } ) return Response(status=status, data=response_data.data) diff --git a/api/test/unit/controllers/test_search_controller_search_query.py b/api/test/unit/controllers/test_search_controller_search_query.py index 3ccd56daded..081bd356c27 100644 --- a/api/test/unit/controllers/test_search_controller_search_query.py +++ b/api/test/unit/controllers/test_search_controller_search_query.py @@ -172,7 +172,6 @@ def test_create_search_query_q_search_with_filters(image_media_type_config): } }, {"rank_feature": {"boost": 10000, "field": "standardized_popularity"}}, - {"rank_feature": {"boost": 25000, "field": "authority_boost"}}, ], } diff --git a/api/test/unit/utils/test_throttle.py b/api/test/unit/utils/test_throttle.py index 1afdaff5325..a727ea2b836 100644 --- a/api/test/unit/utils/test_throttle.py +++ b/api/test/unit/utils/test_throttle.py @@ -71,6 +71,7 @@ def enable_throttles(settings): def access_token(): token = AccessTokenFactory.create() token.application.verified = True + token.application.client_id = 123 token.application.save() return token