Skip to content

Commit

Permalink
Move Oembed endpoint validation onto the serializer (#3069)
Browse files: browse the repository at this point in the history
Co-authored-by: Madison Swain-Bowden <[email protected]>
Co-authored-by: sarayourfriend <[email protected]>
  • Loading branch information
3 people authored Dec 11, 2023
1 parent 8b0ec49 commit 3d18bdb
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 38 deletions.
28 changes: 21 additions & 7 deletions api/api/serializers/image_serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Literal
from uuid import UUID

from django.core.exceptions import ValidationError
from rest_framework import serializers

from api.constants.field_order import field_position_map
Expand All @@ -16,7 +17,6 @@
get_hyperlinks_serializer,
get_search_request_source_serializer,
)
from api.utils.url import add_protocol


#######################
Expand Down Expand Up @@ -119,21 +119,35 @@ class Meta:
class OembedRequestSerializer(serializers.Serializer):
"""Parse and validate oEmbed parameters."""

url = serializers.CharField(
url = serializers.URLField(
allow_blank=False,
help_text="The link to an image present in Openverse.",
)

@staticmethod
def validate_url(value):
url = add_protocol(value)
def to_internal_value(self, data):
data = super().to_internal_value(data)

url = data["url"]
if url.endswith("/"):
url = url[:-1]
identifier = url.rsplit("/", 1)[1]

try:
uuid = UUID(identifier)
except ValueError:
raise serializers.ValidationError("Could not parse identifier from URL.")
return uuid
raise serializers.ValidationError(
{"Could not parse identifier from URL.": data["url"]}
)

try:
image = Image.objects.get(identifier=uuid)
except (Image.DoesNotExist, ValidationError):
raise serializers.ValidationError(
{"Could not find image from the provided URL": data["url"]}
)

data["image"] = image
return data


class OembedSerializer(BaseModelSerializer):
Expand Down
5 changes: 1 addition & 4 deletions api/api/views/image_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from django.conf import settings
from django.http.response import FileResponse, HttpResponse
from django.shortcuts import get_object_or_404
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound
from rest_framework.response import Response
Expand Down Expand Up @@ -111,11 +110,9 @@ def oembed(self, request, *_, **__):

params = OembedRequestSerializer(data=request.query_params)
params.is_valid(raise_exception=True)

image = params.validated_data["image"]
context = self.get_serializer_context()

identifier = params.validated_data["url"]
image = get_object_or_404(Image, identifier=identifier)
if not (image.height and image.width):
image_file = requests.get(image.url, headers=self.OEMBED_HEADERS)
width, height = PILImage.open(io.BytesIO(image_file.content)).size
Expand Down
33 changes: 7 additions & 26 deletions api/test/test_image_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,40 +85,21 @@ def test_audio_report(image_fixture):
report("images", image_fixture)


def test_oembed_endpoint_with_non_existent_image():
params = {
"url": "https://any.domain/any/path/00000000-0000-0000-0000-000000000000",
}
response = requests.get(
f"{API_URL}/v1/images/oembed?{urlencode(params)}", verify=False
)
assert response.status_code == 404


def test_oembed_endpoint_with_bad_identifier():
params = {
"url": "https://any.domain/any/path/not-a-valid-uuid",
}
response = requests.get(
f"{API_URL}/v1/images/oembed?{urlencode(params)}", verify=False
)
assert response.status_code == 400


@pytest.mark.parametrize(
"url",
"url, expected_status_code",
[
f"https://any.domain/any/path/{identifier}", # no trailing slash
f"https://any.domain/any/path/{identifier}/", # trailing slash
identifier, # just identifier instead of URL
(f"https://any.domain/any/path/{identifier}", 200), # no trailing slash
(f"https://any.domain/any/path/{identifier}/", 200), # trailing slash
("https://any.domain/any/path/00000000-0000-0000-0000-000000000000", 400),
("https://any.domain/any/path/not-a-valid-uuid", 400),
],
)
def test_oembed_endpoint_with_fuzzy_input(url):
def test_oembed_endpoint(url, expected_status_code):
params = {"url": url}
response = requests.get(
f"{API_URL}/v1/images/oembed?{urlencode(params)}", verify=False
)
assert response.status_code == 200
assert response.status_code == expected_status_code


def test_oembed_endpoint_for_json():
Expand Down
4 changes: 3 additions & 1 deletion api/test/unit/views/test_image_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from test.constants import API_URL
from test.factory.models.image import ImageFactory
from unittest.mock import patch

Expand Down Expand Up @@ -52,7 +53,8 @@ def requests_get(url, **kwargs):
@pytest.mark.django_db
def test_oembed_sends_ua_header(api_client, requests):
image = ImageFactory.create()
res = api_client.get("/v1/images/oembed/", data={"url": f"/{image.identifier}"})
image.url = f"https://any.domain/any/path/{image.identifier}"
res = api_client.get(f"{API_URL}/v1/images/oembed/", data={"url": image.url})

assert res.status_code == 200

Expand Down

0 comments on commit 3d18bdb

Please sign in to comment.