Skip to content

Commit

Permalink
Respond with 401 for requests with bad credentials (#4126)
Browse files Browse the repository at this point in the history
* Respond with 401 for requests with bad credentials

* Try fixing documentation errors

* Fix schema test failures
  • Loading branch information
sarayourfriend authored Apr 30, 2024
1 parent a85356b commit 575f529
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 34 deletions.
15 changes: 13 additions & 2 deletions api/api/docs/audio_docs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from rest_framework.exceptions import (
AuthenticationFailed,
NotAuthenticated,
NotFound,
ValidationError,
Expand Down Expand Up @@ -75,7 +76,10 @@
By using this endpoint, you can obtain info about content providers such
as {fields_to_md(ProviderSerializer.Meta.fields)}.""",
res={200: (ProviderSerializer(many=True), audio_stats_200_example)},
res={
200: (ProviderSerializer(many=True), audio_stats_200_example),
401: (AuthenticationFailed, None),
},
eg=[audio_stats_curl],
)

Expand All @@ -87,6 +91,7 @@
{fields_to_md(AudioSerializer.Meta.fields)}""",
res={
200: (AudioSerializer, audio_detail_200_example),
401: (AuthenticationFailed, None),
404: (NotFound, audio_detail_404_example),
},
eg=[audio_detail_curl],
Expand All @@ -100,6 +105,7 @@
{fields_to_md(AudioSerializer.Meta.fields)}.""",
res={
200: (AudioSerializer(many=True), audio_related_200_example),
401: (AuthenticationFailed, None),
404: (NotFound, audio_related_404_example),
},
eg=[audio_related_curl],
Expand All @@ -109,18 +115,23 @@
res={
201: (AudioReportRequestSerializer, audio_complain_201_example),
400: (ValidationError, None),
401: (AuthenticationFailed, None),
},
eg=[audio_complain_curl],
)

thumbnail = extend_schema(
parameters=[MediaThumbnailRequestSerializer],
responses={200: OpenApiResponse(description="Thumbnail image")},
responses={
200: OpenApiResponse(description="Thumbnail image"),
401: AuthenticationFailed,
},
)

waveform = custom_extend_schema(
res={
200: (AudioWaveformSerializer, audio_waveform_200_example),
401: (AuthenticationFailed, None),
404: (NotFound, audio_waveform_404_example),
},
eg=[audio_waveform_curl],
Expand Down
82 changes: 75 additions & 7 deletions api/api/docs/base_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

from django.conf import settings
from rest_framework.exceptions import (
APIException,
NotFound,
ValidationError,
)

from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.utils import (
OpenApiExample,
Expand All @@ -31,6 +34,77 @@ def fields_to_md(field_names):
return f"{all_but_last} and `{last}`"


class APIExceptionOpenApiSerializerExtension(OpenApiSerializerExtension):
target_class = APIException
match_subclasses = True

@classmethod
def _get_detail(cls, target):
return getattr(target, "detail", target.default_detail)

def get_name(self, *args):
cls = self.target if isinstance(self.target, type) else self.target.__class__
return cls.__name__

def map_serializer(self, *args):
cls = self.target if isinstance(self.target, type) else self.target.__class__

detail_string = {
"type": "string",
"description": "A description of what went wrong.",
}

if cls == ValidationError or issubclass(cls, ValidationError):
return {
"title": "ValidationError",
"type": "object",
"properties": {
"detail": {
"oneOf": [
detail_string,
{
"type": "object",
"additionalProperties": True,
},
]
}
},
}

return {
"title": cls.__name__,
"type": "object",
"properties": {"detail": detail_string},
}

@classmethod
def exception_example(cls, exception):
if exception == ValidationError:
return {"detail": {"<request parameter>": "<error details>"}}

return {"detail": cls._get_detail(exception)}


def get_examples(code, serializer, example):
if (
not example
and isinstance(serializer, type)
and issubclass(serializer, APIException)
):
example = APIExceptionOpenApiSerializerExtension.exception_example(serializer)
elif example:
example = example["application/json"]
else:
return []

return [
OpenApiExample(
http_responses[code],
value=example,
)
]


def custom_extend_schema(**kwargs):
extend_args = {}

Expand All @@ -51,13 +125,7 @@ def custom_extend_schema(**kwargs):
code: OpenApiResponse(
serializer,
description=http_responses[code],
examples=[
OpenApiExample(
http_responses[code], value=example["application/json"]
)
]
if example
else [],
examples=get_examples(code, serializer, example),
)
for code, (serializer, example) in responses.items()
}
Expand Down
26 changes: 22 additions & 4 deletions api/api/docs/image_docs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from rest_framework.exceptions import (
AuthenticationFailed,
NotAuthenticated,
NotFound,
ValidationError,
Expand Down Expand Up @@ -78,7 +79,10 @@
By using this endpoint, you can obtain info about content providers such
as {fields_to_md(ProviderSerializer.Meta.fields)}.""",
res={200: (ProviderSerializer(many=True), image_stats_200_example)},
res={
200: (ProviderSerializer(many=True), image_stats_200_example),
401: (AuthenticationFailed, None),
},
eg=[image_stats_curl],
)

Expand All @@ -90,6 +94,7 @@
{fields_to_md(ImageSerializer.Meta.fields)}""",
res={
200: (ImageSerializer, image_detail_200_example),
401: (AuthenticationFailed, None),
404: (NotFound, image_detail_404_example),
},
eg=[image_detail_curl],
Expand All @@ -103,6 +108,7 @@
{fields_to_md(ImageSerializer.Meta.fields)}.""",
res={
200: (ImageSerializer, image_related_200_example),
401: (AuthenticationFailed, None),
404: (NotFound, image_related_404_example),
},
eg=[image_related_curl],
Expand All @@ -111,24 +117,36 @@
report = custom_extend_schema(
res={
201: (ImageReportRequestSerializer, image_complain_201_example),
401: (AuthenticationFailed, None),
400: (ValidationError, None),
},
eg=[image_complain_curl],
)

thumbnail = extend_schema(
parameters=[MediaThumbnailRequestSerializer],
responses={200: OpenApiResponse(description="Thumbnail image"), 404: NotFound},
responses={
200: OpenApiResponse(description="Thumbnail image"),
404: NotFound,
401: AuthenticationFailed,
},
)

oembed = custom_extend_schema(
params=OembedRequestSerializer,
res={
200: (OembedSerializer, image_oembed_200_example),
404: (NotFound, image_oembed_404_example),
400: (ValidationError, image_oembed_400_example),
401: (AuthenticationFailed, None),
404: (NotFound, image_oembed_404_example),
},
eg=[image_oembed_curl],
)

watermark = extend_schema(deprecated=True, responses={404: NotFound})
watermark = extend_schema(
deprecated=True,
responses={
401: AuthenticationFailed,
404: NotFound,
},
)
5 changes: 2 additions & 3 deletions api/api/docs/oauth2_docs.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from rest_framework.exceptions import (
APIException,
NotAuthenticated,
PermissionDenied,
ValidationError,
)

from api.docs.base_docs import custom_extend_schema
from api.examples import (
auth_key_info_200_example,
auth_key_info_403_example,
auth_key_info_curl,
auth_register_201_example,
auth_register_curl,
Expand All @@ -30,6 +28,7 @@
res={
201: (OAuth2ApplicationSerializer, auth_register_201_example),
400: (ValidationError, None),
401: ({"type": "object", "properties": {"error": {"type": "string"}}}, None),
429: (
APIException("Request was throttled. Expected available in 1 second.", 429),
None,
Expand All @@ -42,7 +41,7 @@
operation_id="key_info",
res={
200: (OAuth2KeyInfoSerializer, auth_key_info_200_example),
403: (PermissionDenied, auth_key_info_403_example),
401: (NotAuthenticated, None),
429: (
APIException("Request was throttled. Expected available in 1 second.", 429),
None,
Expand Down
1 change: 0 additions & 1 deletion api/api/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
)
from api.examples.oauth2_responses import (
auth_key_info_200_example,
auth_key_info_403_example,
auth_register_201_example,
auth_token_200_example,
)
Expand Down
2 changes: 0 additions & 2 deletions api/api/examples/oauth2_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,3 @@
"rate_limit_model": "enhanced",
}
}

auth_key_info_403_example = {"application/json": "Forbidden"}
23 changes: 12 additions & 11 deletions api/api/views/oauth2_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from django.core.cache import cache
from django.core.mail import send_mail
from django.db import DataError
from rest_framework.exceptions import APIException, PermissionDenied
from rest_framework.exceptions import APIException
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.views import APIView

from drf_spectacular.utils import extend_schema
from oauth2_provider.contrib.rest_framework.permissions import TokenHasScope
from oauth2_provider.generators import generate_client_secret
from oauth2_provider.views import TokenView as BaseTokenView
from redis.exceptions import ConnectionError
Expand All @@ -40,6 +41,8 @@ class InvalidCredentials(APIException):
@extend_schema(tags=["auth"])
class Register(APIView):
throttle_classes = (TenPerDay,)
# Registration implicitly does not require authentication
authentication_classes = ()

@register
def post(self, request, format=None):
Expand Down Expand Up @@ -150,6 +153,10 @@ def get(self, request, code, format=None):

@extend_schema(tags=["auth"])
class TokenView(APIView, BaseTokenView):
# Token view is pre-authentication
authentication_classes = ()
permission_classes = ()

@token
def post(self, request):
"""
Expand Down Expand Up @@ -178,6 +185,8 @@ def post(self, request):
@extend_schema(tags=["auth"])
class CheckRates(APIView):
throttle_classes = (OnePerSecond,)
permission_classes = (TokenHasScope,)
required_scopes = ("read",)

@key_info
def get(self, request: Request, format=None):
Expand All @@ -187,21 +196,13 @@ def get(self, request: Request, format=None):
You can use this endpoint to get information about your API key such as
`requests_this_minute`, `requests_today`, and `rate_limit_model`.
> ℹ️ **NOTE:** If you get a 403 Forbidden response, it means your access
> token has expired.
> ℹ️ **NOTE:** If you get a 401 Unauthorized, it means your token is invalid
> (malformed, non-existent, or expired).
"""

# TODO: Replace 403 responses with DRF `authentication_classes`.
if not request.auth or not hasattr(request.auth, "application"):
raise PermissionDenied("Forbidden", 403)

application: ThrottledApplication = request.auth.application

client_id = application.client_id

if not client_id:
raise PermissionDenied("Forbidden", 403)

throttle_type = application.rate_limit_model
throttle_key = "throttle_{scope}_{client_id}"
if throttle_type == "standard":
Expand Down
27 changes: 27 additions & 0 deletions api/conf/oauth2_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from rest_framework.exceptions import AuthenticationFailed

from drf_spectacular.authentication import TokenScheme
from oauth2_provider.contrib.rest_framework import (
OAuth2Authentication as BaseOAuth2Authentication,
)


class OAuth2Authentication(BaseOAuth2Authentication):
# Required by schema extension
keyword = "Bearer"

def authenticate(self, request):
result = super().authenticate(request)
if getattr(request, "oauth2_error", None):
# oauth2_error is only defined on requests that had errors
# it will be undefined or empty for anonymous requests and
# requests with valid credentials
# `request` is mutated by `super().authenticate`
raise AuthenticationFailed()

return result


class OAuth2OpenApiAuthenticationExtension(TokenScheme):
target_class = "conf.oauth2_extensions.OAuth2Authentication"
name = "Openverse API Token"
4 changes: 1 addition & 3 deletions api/conf/settings/rest_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@
)

REST_FRAMEWORK = {
"DEFAULT_AUTHENTICATION_CLASSES": (
"oauth2_provider.contrib.rest_framework.OAuth2Authentication",
),
"DEFAULT_AUTHENTICATION_CLASSES": ("conf.oauth2_extensions.OAuth2Authentication",),
"DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning",
"DEFAULT_RENDERER_CLASSES": (
"rest_framework.renderers.JSONRenderer",
Expand Down
Loading

0 comments on commit 575f529

Please sign in to comment.