From 3d18bdb84e425dff8bc39f6233c2aa1cd4ebb0b7 Mon Sep 17 00:00:00 2001 From: Krystle Salazar Date: Mon, 11 Dec 2023 14:50:39 -0400 Subject: [PATCH] Move Oembed endpoint validation onto the serializer (#3069) Co-authored-by: Madison Swain-Bowden Co-authored-by: sarayourfriend --- api/api/serializers/image_serializers.py | 28 +++++++++++++++----- api/api/views/image_views.py | 5 +--- api/test/test_image_integration.py | 33 +++++------------------- api/test/unit/views/test_image_views.py | 4 ++- 4 files changed, 32 insertions(+), 38 deletions(-) diff --git a/api/api/serializers/image_serializers.py b/api/api/serializers/image_serializers.py index d483a5d5e2e..27dd63df907 100644 --- a/api/api/serializers/image_serializers.py +++ b/api/api/serializers/image_serializers.py @@ -1,6 +1,7 @@ from typing import Literal from uuid import UUID +from django.core.exceptions import ValidationError from rest_framework import serializers from api.constants.field_order import field_position_map @@ -16,7 +17,6 @@ get_hyperlinks_serializer, get_search_request_source_serializer, ) -from api.utils.url import add_protocol ####################### @@ -119,21 +119,35 @@ class Meta: class OembedRequestSerializer(serializers.Serializer): """Parse and validate oEmbed parameters.""" - url = serializers.CharField( + url = serializers.URLField( + allow_blank=False, help_text="The link to an image present in Openverse.", ) - @staticmethod - def validate_url(value): - url = add_protocol(value) + def to_internal_value(self, data): + data = super().to_internal_value(data) + + url = data["url"] if url.endswith("/"): url = url[:-1] identifier = url.rsplit("/", 1)[1] + try: uuid = UUID(identifier) except ValueError: - raise serializers.ValidationError("Could not parse identifier from URL.") - return uuid + raise serializers.ValidationError( + {"Could not parse identifier from URL.": data["url"]} + ) + + try: + image = Image.objects.get(identifier=uuid) + except (Image.DoesNotExist, ValidationError): + raise serializers.ValidationError( + {"Could not find image from the provided URL": data["url"]} + ) + + data["image"] = image + return data class OembedSerializer(BaseModelSerializer): diff --git a/api/api/views/image_views.py b/api/api/views/image_views.py index 955466bcf57..597fd84eafa 100644 --- a/api/api/views/image_views.py +++ b/api/api/views/image_views.py @@ -2,7 +2,6 @@ from django.conf import settings from django.http.response import FileResponse, HttpResponse -from django.shortcuts import get_object_or_404 from rest_framework.decorators import action from rest_framework.exceptions import NotFound from rest_framework.response import Response @@ -111,11 +110,9 @@ def oembed(self, request, *_, **__): params = OembedRequestSerializer(data=request.query_params) params.is_valid(raise_exception=True) - + image = params.validated_data["image"] context = self.get_serializer_context() - identifier = params.validated_data["url"] - image = get_object_or_404(Image, identifier=identifier) if not (image.height and image.width): image_file = requests.get(image.url, headers=self.OEMBED_HEADERS) width, height = PILImage.open(io.BytesIO(image_file.content)).size diff --git a/api/test/test_image_integration.py b/api/test/test_image_integration.py index 910633ad644..423db2534e1 100644 --- a/api/test/test_image_integration.py +++ b/api/test/test_image_integration.py @@ -85,40 +85,21 @@ def test_audio_report(image_fixture): report("images", image_fixture) -def test_oembed_endpoint_with_non_existent_image(): - params = { - "url": "https://any.domain/any/path/00000000-0000-0000-0000-000000000000", - } - response = requests.get( - f"{API_URL}/v1/images/oembed?{urlencode(params)}", verify=False - ) - assert response.status_code == 404 - - -def test_oembed_endpoint_with_bad_identifier(): - params = { - "url": "https://any.domain/any/path/not-a-valid-uuid", - } - response = requests.get( - f"{API_URL}/v1/images/oembed?{urlencode(params)}", verify=False - ) - assert response.status_code == 400 - - @pytest.mark.parametrize( - "url", + "url, expected_status_code", [ - f"https://any.domain/any/path/{identifier}", # no trailing slash - f"https://any.domain/any/path/{identifier}/", # trailing slash - identifier, # just identifier instead of URL + (f"https://any.domain/any/path/{identifier}", 200), # no trailing slash + (f"https://any.domain/any/path/{identifier}/", 200), # trailing slash + ("https://any.domain/any/path/00000000-0000-0000-0000-000000000000", 400), + ("https://any.domain/any/path/not-a-valid-uuid", 400), ], ) -def test_oembed_endpoint_with_fuzzy_input(url): +def test_oembed_endpoint(url, expected_status_code): params = {"url": url} response = requests.get( f"{API_URL}/v1/images/oembed?{urlencode(params)}", verify=False ) - assert response.status_code == 200 + assert response.status_code == expected_status_code def test_oembed_endpoint_for_json(): diff --git a/api/test/unit/views/test_image_views.py b/api/test/unit/views/test_image_views.py index dc258dfc246..3b919809c09 100644 --- a/api/test/unit/views/test_image_views.py +++ b/api/test/unit/views/test_image_views.py @@ -2,6 +2,7 @@ from collections.abc import Callable from dataclasses import dataclass from pathlib import Path +from test.constants import API_URL from test.factory.models.image import ImageFactory from unittest.mock import patch @@ -52,7 +53,8 @@ def requests_get(url, **kwargs): @pytest.mark.django_db def test_oembed_sends_ua_header(api_client, requests): image = ImageFactory.create() - res = api_client.get("/v1/images/oembed/", data={"url": f"/{image.identifier}"}) + image.url = f"https://any.domain/any/path/{image.identifier}" + res = api_client.get(f"{API_URL}/v1/images/oembed/", data={"url": image.url}) assert res.status_code == 200