From 520d4d97aef6bc161f787024316009807fc1742a Mon Sep 17 00:00:00 2001 From: sarayourfriend <24264157+sarayourfriend@users.noreply.github.com> Date: Mon, 27 Nov 2023 08:56:02 +1100 Subject: [PATCH] Add middleware to log application name and verification status (#3369) * Add middleware to log application name and verification status Also add authorization, app name and verified status to nginx logs * Use clearer name for new middlware --- .../middleware/response_headers_middleware.py | 25 ++++++++ api/api/utils/oauth2_helper.py | 33 +++++++--- api/api/utils/throttle.py | 17 +++--- api/api/views/oauth2_views.py | 13 +++- api/conf/settings/base.py | 1 + api/nginx.conf.template | 8 ++- api/test/test_auth.py | 60 +++++++++++++++---- 7 files changed, 123 insertions(+), 34 deletions(-) create mode 100644 api/api/middleware/response_headers_middleware.py diff --git a/api/api/middleware/response_headers_middleware.py b/api/api/middleware/response_headers_middleware.py new file mode 100644 index 00000000000..76b21e4ae6a --- /dev/null +++ b/api/api/middleware/response_headers_middleware.py @@ -0,0 +1,25 @@ +from api.utils.oauth2_helper import get_token_info + + +def response_headers_middleware(get_response): + """ + Add standard response headers used by Nginx logging. + + These headers help Openverse more easily and directly connect + individual requests to each other. This is particularly useful + when evaluating traffic patterns from individual source IPs + to identify malicious requesters or request patterns. + """ + + def middleware(request): + response = get_response(request) + + if hasattr(request, "auth") and request.auth: + token_info = get_token_info(str(request.auth)) + if token_info: + response["x-ov-client-application-name"] = token_info.application_name + response["x-ov-client-application-verified"] = token_info.verified + + return response + + return middleware diff --git a/api/api/utils/oauth2_helper.py b/api/api/utils/oauth2_helper.py index c6df07202c4..6ffd0c4c0c9 100644 --- a/api/api/utils/oauth2_helper.py +++ b/api/api/utils/oauth2_helper.py @@ -1,5 +1,6 @@ import datetime as dt import logging +from dataclasses import dataclass from oauth2_provider.models import AccessToken @@ -8,10 +9,22 @@ parent_logger = logging.getLogger(__name__) -_no_result = (None, None, None) +@dataclass +class TokenInfo: + """Extracted ``models.ThrottledApplication`` metadata.""" -def get_token_info(token: str): + client_id: str + rate_limit_model: str + verified: bool + application_name: str + + @property + def valid(self): + return self.client_id and self.verified + + +def get_token_info(token: str) -> None | TokenInfo: """ Recover an OAuth2 application client ID and rate limit model from an access token. @@ -24,7 +37,7 @@ def get_token_info(token: str): try: token = AccessToken.objects.get(token=token) except AccessToken.DoesNotExist: - return _no_result + return None try: application = models.ThrottledApplication.objects.get(accesstoken=token) @@ -33,7 +46,7 @@ def get_token_info(token: str): # In practice should never occur so long as the preceding # operation to retrieve the access token was successful. logger.critical("Failed to find application associated with access token.") - return _no_result + return None expired = token.expires < dt.datetime.now(token.expires.tzinfo) if expired: @@ -42,9 +55,11 @@ def get_token_info(token: str): f"application.name={application.name} " f"application.client_id={application.client_id} " ) - return _no_result + return None - client_id = str(application.client_id) - rate_limit_model = application.rate_limit_model - verified = application.verified - return client_id, rate_limit_model, verified + return TokenInfo( + client_id=str(application.client_id), + rate_limit_model=application.rate_limit_model, + verified=application.verified, + application_name=application.name, + ) diff --git a/api/api/utils/throttle.py b/api/api/utils/throttle.py index 7fb19f7c337..855e4467a04 100644 --- a/api/api/utils/throttle.py +++ b/api/api/utils/throttle.py @@ -28,7 +28,6 @@ def headers(self): contains the limit and the number of requests left in the limit. Since multiple rate limits can apply concurrently, the suffix identifies each pair uniquely. """ - prefix = "X-RateLimit" suffix = self.scope or self.__class__.__name__.lower() if hasattr(self, "history"): @@ -53,8 +52,8 @@ def get_cache_key(self, request, view): logger = self.logger.getChild("get_cache_key") # Do not apply anonymous throttle to request with valid tokens. if request.auth: - client_id, _, verified = get_token_info(str(request.auth)) - if client_id and verified: + token_info = get_token_info(str(request.auth)) + if token_info and token_info.valid: return None ident = self.get_ident(request) @@ -113,14 +112,14 @@ class AbstractOAuth2IdRateThrottle(SimpleRateThrottleHeader, metaclass=abc.ABCMe def get_cache_key(self, request, view): # Find the client ID associated with the access token. auth = str(request.auth) - client_id, rate_limit_model, verified = get_token_info(auth) - if client_id and rate_limit_model == self.applies_to_rate_limit_model: - ident = client_id - else: - # Return None, fallback to the anonymous rate limiting + token_info = get_token_info(auth) + if not (token_info and token_info.valid): + return None + + if token_info.rate_limit_model != self.applies_to_rate_limit_model: return None - return self.cache_format % {"scope": self.scope, "ident": ident} + return self.cache_format % {"scope": self.scope, "ident": token_info.client_id} class OAuth2IdThumbnailRateThrottle(AbstractOAuth2IdRateThrottle): diff --git a/api/api/views/oauth2_views.py b/api/api/views/oauth2_views.py index 30864bf0490..dc72cbfd76b 100644 --- a/api/api/views/oauth2_views.py +++ b/api/api/views/oauth2_views.py @@ -181,12 +181,19 @@ def get(self, request, format=None): return Response(status=403, data="Forbidden") access_token = str(request.auth) - client_id, rate_limit_model, verified = get_token_info(access_token) + token_info = get_token_info(access_token) + + if not token_info: + # This shouldn't happen if `request.auth` was true above, + # but better safe than sorry + return Response(status=403, data="Forbidden") + + client_id = token_info.client_id if not client_id: return Response(status=403, data="Forbidden") - throttle_type = rate_limit_model + throttle_type = token_info.rate_limit_model throttle_key = "throttle_{scope}_{client_id}" if throttle_type == "standard": sustained_throttle_key = throttle_key.format( @@ -223,7 +230,7 @@ def get(self, request, format=None): "requests_this_minute": burst_requests, "requests_today": sustained_requests, "rate_limit_model": throttle_type, - "verified": verified, + "verified": token_info.verified, } ) return Response(status=200, data=response_data.data) diff --git a/api/conf/settings/base.py b/api/conf/settings/base.py index a547f28abab..de29e89340c 100644 --- a/api/conf/settings/base.py +++ b/api/conf/settings/base.py @@ -22,6 +22,7 @@ "django.contrib.auth.middleware.AuthenticationMiddleware", "django.contrib.messages.middleware.MessageMiddleware", "django.middleware.clickjacking.XFrameOptionsMiddleware", + "api.middleware.response_headers_middleware.response_headers_middleware", ] # Storage diff --git a/api/nginx.conf.template b/api/nginx.conf.template index a9edeb391d3..5e11959c844 100644 --- a/api/nginx.conf.template +++ b/api/nginx.conf.template @@ -17,10 +17,14 @@ log_format json_combined escape=json '"host_header": "$host",' '"body_bytes_sent":$body_bytes_sent,' '"request_time":"$request_time",' + '"upstream_response_time":$upstream_response_time,' '"http_referrer":"$http_referer",' '"http_user_agent":"$http_user_agent",' - '"upstream_response_time":$upstream_response_time,' - '"http_x_forwarded_for":"$http_x_forwarded_for"' + '"http_x_forwarded_for":"$http_x_forwarded_for",' + '"http_authorization":"$http_authorization",' + '"request_id":"$sent_http_x_request_id",' + '"client_application_name":"$sent_http_x_ov_client_application_name",' + '"client_application_verified":"$sent_http_x_ov_client_application_verified"' '}'; access_log /var/log/nginx/access.log json_combined; diff --git a/api/test/test_auth.py b/api/test/test_auth.py index c2dc6688a9e..80441b30c0f 100644 --- a/api/test/test_auth.py +++ b/api/test/test_auth.py @@ -62,23 +62,32 @@ def test_auth_token_exchange_unsupported_method(client): assert res.json()["detail"] == 'Method "GET" not allowed.' +def _integration_verify_most_recent_token(client): + verify = OAuth2Verification.objects.last() + code = verify.code + path = reverse("verify-email", args=[code]) + return client.get(path) + + @pytest.mark.django_db @pytest.mark.parametrize( "rate_limit_model", [x[0] for x in ThrottledApplication.RATE_LIMIT_MODELS], ) +@pytest.mark.skipif( + API_URL != "http://localhost:8000", + reason=( + "This test needs to cheat by looking in the database," + " so it needs to skip in non-local environments where" + " that isn't possible." + ), +) def test_auth_email_verification(client, rate_limit_model, test_auth_token_exchange): - # This test needs to cheat by looking in the database, so it will be - # skipped in non-local environments. - if API_URL == "http://localhost:8000": - verify = OAuth2Verification.objects.last() - code = verify.code - path = reverse("verify-email", args=[code]) - res = client.get(path) - assert res.status_code == 200 - test_auth_rate_limit_reporting( - client, rate_limit_model, test_auth_token_exchange, verified=True - ) + res = _integration_verify_most_recent_token(client) + assert res.status_code == 200 + test_auth_rate_limit_reporting( + client, rate_limit_model, test_auth_token_exchange, verified=True + ) @pytest.mark.django_db @@ -106,6 +115,35 @@ def test_auth_rate_limit_reporting( assert res_data["verified"] is False +@pytest.mark.django_db +@pytest.mark.parametrize( + "verified", + (True, False), +) +def test_auth_response_headers( + client, verified, test_auth_tokens_registration, test_auth_token_exchange +): + if verified: + _integration_verify_most_recent_token(client) + + token = test_auth_token_exchange["access_token"] + + res = client.get("/v1/images/", HTTP_AUTHORIZATION=f"Bearer {token}") + + assert ( + res.headers["x-ov-client-application-name"] + == test_auth_tokens_registration["name"] + ) + assert res.headers["x-ov-client-application-verified"] == str(verified) + + +def test_unauthed_response_headers(client): + res = client.get("/v1/images") + + assert "x-ov-client-application-name" not in res.headers + assert "x-ov-client-application-verified" not in res.headers + + @pytest.mark.django_db @pytest.mark.parametrize( "sort_dir, exp_indexed_on",