From b91cdc313977331e9462dfb6a322557f221e564b Mon Sep 17 00:00:00 2001 From: sarayourfriend <24264157+sarayourfriend@users.noreply.github.com> Date: Wed, 22 Nov 2023 11:47:14 +1100 Subject: [PATCH] Add ADRF and make the thumbnail view async (#3020) * Add ADRF and make the thumbnail view async This does not convert the `image_proxy.get` to async to avoid making this initial PR more complicated than it needs to be, and to avoid being blocked on the aiohttp client session sharing which we will want to have in before we convert the image proxy to use aiohttp. * Fix bad merge * Fix bad merge * Fix docstrings on api docs * Remove unnecessary thread sensitivity * Fix bad subclassing of thumbnails endpoints * Use simpler/clearer names for image proxy config objects --- api/Pipfile | 5 +- api/Pipfile.lock | 17 +++++- api/api/utils/image_proxy/__init__.py | 45 +++++++++----- api/api/views/audio_views.py | 38 ++++++------ api/api/views/image_views.py | 35 +++++------ api/api/views/media_views.py | 43 ++++++++++--- api/test/unit/utils/test_image_proxy.py | 81 +++++++++++++++---------- api/test/unit/views/test_image_views.py | 38 ++++++++---- 8 files changed, 193 insertions(+), 109 deletions(-) diff --git a/api/Pipfile b/api/Pipfile index d824cdb4cd5..e2d8e9a0430 100644 --- a/api/Pipfile +++ b/api/Pipfile @@ -18,6 +18,7 @@ pytest-sugar = "~=0.9" pook = {ref = "master", git = "git+https://github.com/h2non/pook.git"} [packages] +adrf = "~=0.1.2" aiohttp = "~=3.8" aws-requests-auth = "~=0.4" deepdiff = "~=6.4" @@ -26,6 +27,7 @@ django-cors-headers = "~=4.2" django-log-request-id = "~=2.0" django-oauth-toolkit = "~=2.3" django-redis = "~=5.4" +django-split-settings = "*" django-tqdm = "~=1.3" django-uuslug = "~=2.0" djangorestframework = "~=3.14" @@ -35,11 +37,10 @@ elasticsearch-dsl = "~=8.9" future = "~=0.18" limit = "~=0.2" Pillow = "~=10.1.0" +psycopg = "~=3.1" python-decouple = "~=3.8" python-xmp-toolkit = "~=2.0" sentry-sdk = "~=1.30" -django-split-settings = "*" -psycopg = "~=3.1" uvicorn = {extras = ["standard"], version = "~=0.23"} [requires] diff --git a/api/Pipfile.lock b/api/Pipfile.lock index 0c13736001c..7a2ade8cfef 100644 --- a/api/Pipfile.lock +++ b/api/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "7e9adb0878e2d7523c7457b3712adde0aee93b6586d54eef2663d09e145a1b3e" + "sha256": "54293a6311c5ebb7d16bf6bb13d9a63f420b9275855cba28076d08f125ae95c2" }, "pipfile-spec": 6, "requires": { @@ -16,6 +16,14 @@ ] }, "default": { + "adrf": { + "hashes": [ + "sha256:a33f8f51f0f80072ffb2af061df1fb119bc00adaa720a2972049d4aa33155337", + "sha256:ce7160878ba27999d333752941cde0687c1a205fc26fa0eda1bad3924958dc69" + ], + "index": "pypi", + "version": "==0.1.2" + }, "aiohttp": { "hashes": [ "sha256:002f23e6ea8d3dd8d149e569fd580c999232b5fbc601c48d55398fbc2e582e8c", @@ -133,6 +141,13 @@ "markers": "python_version >= '3.7'", "version": "==3.7.2" }, + "async-property": { + "hashes": [ + "sha256:17d9bd6ca67e27915a75d92549df64b5c7174e9dc806b30a3934dc4ff0506380", + "sha256:8924d792b5843994537f8ed411165700b27b2bd966cefc4daeefc1253442a9d7" + ], + "version": "==0.2.2" + }, "async-timeout": { "hashes": [ "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f", diff --git a/api/api/utils/image_proxy/__init__.py b/api/api/utils/image_proxy/__init__.py index 6a591a0c61f..bcfb44db6aa 100644 --- a/api/api/utils/image_proxy/__init__.py +++ b/api/api/utils/image_proxy/__init__.py @@ -1,4 +1,5 @@ import logging +from dataclasses import dataclass from typing import Literal from urllib.parse import urlparse @@ -33,13 +34,26 @@ THUMBNAIL_STRATEGY = Literal["photon_proxy", "original"] +@dataclass +class MediaInfo: + media_provider: str + media_identifier: str + image_url: str + + +@dataclass +class RequestConfig: + accept_header: str = "image/*" + is_full_size: bool = False + is_compressed: bool = True + + def get_request_params_for_extension( ext: str, headers: dict[str, str], image_url: str, parsed_image_url: urlparse, - is_full_size: bool, - is_compressed: bool, + request_config: RequestConfig, ) -> tuple[str, dict[str, str], dict[str, str]]: """ Get the request params (url, params, headers) for the thumbnail proxy. @@ -49,7 +63,10 @@ def get_request_params_for_extension( """ if ext in PHOTON_TYPES: return get_photon_request_params( - parsed_image_url, is_full_size, is_compressed, headers + parsed_image_url, + request_config.is_full_size, + request_config.is_compressed, + headers, ) elif ext in ORIGINAL_TYPES: return image_url, {}, headers @@ -59,24 +76,23 @@ def get_request_params_for_extension( def get( - image_url: str, - media_identifier: str, - media_provider: str, - accept_header: str = "image/*", - is_full_size: bool = False, - is_compressed: bool = True, + media_info: MediaInfo, + request_config: RequestConfig = RequestConfig(), ) -> HttpResponse: """ Proxy an image through Photon if its file type is supported, else return the original image if the file type is SVG. Otherwise, raise an exception. """ + image_url = media_info.image_url + media_identifier = media_info.media_identifier + logger = parent_logger.getChild("get") tallies = django_redis.get_redis_connection("tallies") month = get_monthly_timestamp() image_extension = get_image_extension(image_url, media_identifier) - headers = {"Accept": accept_header} | HEADERS + headers = {"Accept": request_config.accept_header} | HEADERS parsed_image_url = urlparse(image_url) domain = parsed_image_url.netloc @@ -86,8 +102,7 @@ def get( headers, image_url, parsed_image_url, - is_full_size, - is_compressed, + request_config, ) try: @@ -103,7 +118,7 @@ def get( f"{month}:{upstream_response.status_code}" ) tallies.incr( - f"thumbnail_response_code_by_provider:{media_provider}:" + f"thumbnail_response_code_by_provider:{media_info.media_provider}:" f"{month}:{upstream_response.status_code}" ) upstream_response.raise_for_status() @@ -133,7 +148,9 @@ def get( f"thumbnail_http_error:{domain}:{month}:{code}:{exc.response.text}" ) logger.warning( - f"Failed to render thumbnail {upstream_url=} {code=} {media_provider=}" + f"Failed to render thumbnail " + f"{upstream_url=} {code=} " + f"{media_info.media_provider=}" ) raise UpstreamThumbnailException(f"Failed to render thumbnail. {exc}") diff --git a/api/api/views/audio_views.py b/api/api/views/audio_views.py index 6457b233a09..dcde7266b92 100644 --- a/api/api/views/audio_views.py +++ b/api/api/views/audio_views.py @@ -15,9 +15,9 @@ source_collection, stats, tag_collection, - thumbnail, - waveform, ) +from api.docs.audio_docs import thumbnail as thumbnail_docs +from api.docs.audio_docs import waveform from api.models import Audio from api.serializers.audio_serializers import ( AudioCollectionRequestSerializer, @@ -26,7 +26,7 @@ AudioSerializer, AudioWaveformSerializer, ) -from api.serializers.media_serializers import MediaThumbnailRequestSerializer +from api.utils import image_proxy from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle from api.views.media_views import MediaViewSet @@ -80,21 +80,8 @@ def source_collection(self, request, source): def tag_collection(self, request, tag, *_, **__): return super().tag_collection(request, tag, *_, **__) - @thumbnail - @action( - detail=True, - url_path="thumb", - url_name="thumb", - serializer_class=MediaThumbnailRequestSerializer, - throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle], - ) - def thumbnail(self, request, *_, **__): - """ - Retrieve the scaled down and compressed thumbnail of the artwork of an - audio track or its audio set. - """ - - audio = self.get_object() + async def get_image_proxy_media_info(self) -> image_proxy.MediaInfo: + audio = await self.aget_object() image_url = None if audio_thumbnail := audio.thumbnail: @@ -104,7 +91,20 @@ def thumbnail(self, request, *_, **__): if not image_url: raise NotFound("Could not find artwork.") - return super().thumbnail(request, audio, image_url) + return image_proxy.MediaInfo( + media_identifier=audio.identifier, + media_provider=audio.provider, + image_url=image_url, + ) + + @thumbnail_docs + @MediaViewSet.thumbnail_action + async def thumbnail(self, *args, **kwargs): + """ + Retrieve the scaled down and compressed thumbnail of the artwork of an + audio track or its audio set. + """ + return await super().thumbnail(*args, **kwargs) @waveform @action( diff --git a/api/api/views/image_views.py b/api/api/views/image_views.py index 641a230f384..955466bcf57 100644 --- a/api/api/views/image_views.py +++ b/api/api/views/image_views.py @@ -22,8 +22,8 @@ source_collection, stats, tag_collection, - thumbnail, ) +from api.docs.image_docs import thumbnail as thumbnail_docs from api.docs.image_docs import watermark as watermark_doc from api.models import Image from api.serializers.image_serializers import ( @@ -34,11 +34,8 @@ OembedSerializer, WatermarkRequestSerializer, ) -from api.serializers.media_serializers import ( - MediaThumbnailRequestSerializer, - PaginatedRequestSerializer, -) -from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle +from api.serializers.media_serializers import PaginatedRequestSerializer +from api.utils import image_proxy from api.utils.watermark import watermark from api.views.media_views import MediaViewSet @@ -130,25 +127,25 @@ def oembed(self, request, *_, **__): serializer = self.get_serializer(image, context=context) return Response(data=serializer.data) - @thumbnail - @action( - detail=True, - url_path="thumb", - url_name="thumb", - serializer_class=MediaThumbnailRequestSerializer, - throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle], - ) - def thumbnail(self, request, *_, **__): - """Retrieve the scaled down and compressed thumbnail of the image.""" - - image = self.get_object() + async def get_image_proxy_media_info(self) -> image_proxy.MediaInfo: + image = await self.aget_object() image_url = image.url # Hotfix to use thumbnails for SMK images # TODO: Remove when small thumbnail issues are resolved if "iip.smk.dk" in image_url and image.thumbnail: image_url = image.thumbnail - return super().thumbnail(request, image, image_url) + return image_proxy.MediaInfo( + media_identifier=image.identifier, + media_provider=image.provider, + image_url=image_url, + ) + + @thumbnail_docs + @MediaViewSet.thumbnail_action + async def thumbnail(self, *args, **kwargs): + """Retrieve the scaled down and compressed thumbnail of the image.""" + return await super().thumbnail(*args, **kwargs) @watermark_doc @action(detail=True, url_path="watermark", url_name="watermark") diff --git a/api/api/views/media_views.py b/api/api/views/media_views.py index ecf6005f410..d003feec9fd 100644 --- a/api/api/views/media_views.py +++ b/api/api/views/media_views.py @@ -7,6 +7,10 @@ from rest_framework.response import Response from rest_framework.viewsets import ReadOnlyModelViewSet +from adrf.views import APIView as AsyncAPIView +from adrf.viewsets import ViewSetMixin as AsyncViewSetMixin +from asgiref.sync import sync_to_async + from api.constants.media_types import MediaType from api.constants.search import SearchStrategy from api.controllers import search_controller @@ -18,6 +22,7 @@ from api.utils import image_proxy from api.utils.pagination import StandardPagination from api.utils.search_context import SearchContext +from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle logger = logging.getLogger(__name__) @@ -35,7 +40,12 @@ class InvalidSource(APIException): default_code = "invalid_source" -class MediaViewSet(ReadOnlyModelViewSet): +image_proxy_aget = sync_to_async(image_proxy.get) + + +class MediaViewSet(AsyncViewSetMixin, AsyncAPIView, ReadOnlyModelViewSet): + view_is_async = True + lookup_field = "identifier" # TODO: https://github.com/encode/django-rest-framework/pull/6789 lookup_value_regex = ( @@ -79,6 +89,8 @@ def get_queryset(self): ).values_list("provider_identifier") ) + aget_object = sync_to_async(ReadOnlyModelViewSet.get_object) + def get_serializer_context(self): context = super().get_serializer_context() req_serializer = self._get_request_serializer(self.request) @@ -265,16 +277,31 @@ def report(self, request, identifier): return Response(data=serializer.data, status=status.HTTP_201_CREATED) - def thumbnail(self, request, media_obj, image_url): + async def get_image_proxy_media_info(self) -> image_proxy.MediaInfo: + raise NotImplementedError( + "Subclasses must implement `get_image_proxy_media_info`" + ) + + thumbnail_action = action( + detail=True, + url_path="thumb", + url_name="thumb", + serializer_class=media_serializers.MediaThumbnailRequestSerializer, + throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle], + ) + + async def thumbnail(self, request, *_, **__): serializer = self.get_serializer(data=request.query_params) serializer.is_valid(raise_exception=True) - return image_proxy.get( - image_url, - media_obj.identifier, - media_obj.provider, - accept_header=request.headers.get("Accept", "image/*"), - **serializer.validated_data, + media_info = await self.get_image_proxy_media_info() + + return await image_proxy_aget( + media_info, + request_config=image_proxy.RequestConfig( + accept_header=request.headers.get("Accept", "image/*"), + **serializer.validated_data, + ), ) # Helper functions diff --git a/api/test/unit/utils/test_image_proxy.py b/api/test/unit/utils/test_image_proxy.py index 35b047f5b7a..5f2aa58f15a 100644 --- a/api/test/unit/utils/test_image_proxy.py +++ b/api/test/unit/utils/test_image_proxy.py @@ -1,3 +1,4 @@ +from dataclasses import replace from test.factory.models.image import ImageFactory from unittest.mock import MagicMock from urllib.parse import urlencode @@ -9,7 +10,13 @@ import pytest import requests -from api.utils.image_proxy import HEADERS, UpstreamThumbnailException, extension +from api.utils.image_proxy import ( + HEADERS, + MediaInfo, + RequestConfig, + UpstreamThumbnailException, + extension, +) from api.utils.image_proxy import get as photon_get from api.utils.tallies import get_monthly_timestamp @@ -19,6 +26,12 @@ TEST_MEDIA_IDENTIFIER = "123" TEST_MEDIA_PROVIDER = "foo" +TEST_MEDIA_INFO = MediaInfo( + media_identifier=TEST_MEDIA_IDENTIFIER, + media_provider=TEST_MEDIA_PROVIDER, + image_url=TEST_IMAGE_URL, +) + UA_HEADER = HEADERS["User-Agent"] # cannot use actual image response because I kept running into some issue with @@ -60,7 +73,7 @@ def test_get_successful_no_auth_key_default_args(mock_image_data): .mock ) - res = photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + res = photon_get(TEST_MEDIA_INFO) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -78,12 +91,12 @@ def test_get_successful_original_svg_no_auth_key_default_args(mock_image_data): .mock ) - res = photon_get( - TEST_IMAGE_URL.replace(".jpg", ".svg"), - TEST_MEDIA_IDENTIFIER, - TEST_MEDIA_PROVIDER, + media_info = replace( + TEST_MEDIA_INFO, image_url=TEST_MEDIA_INFO.image_url.replace(".jpg", ".svg") ) + res = photon_get(media_info) + assert res.content == SVG_BODY.encode() assert res.status_code == 200 assert mock_get.matched @@ -107,7 +120,7 @@ def test_get_successful_with_auth_key_default_args(mock_image_data, auth_key): .mock ) - res = photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + res = photon_get(TEST_MEDIA_INFO) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -130,9 +143,7 @@ def test_get_successful_no_auth_key_not_compressed(mock_image_data): .mock ) - res = photon_get( - TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER, is_compressed=False - ) + res = photon_get(TEST_MEDIA_INFO, RequestConfig(is_compressed=False)) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -155,9 +166,7 @@ def test_get_successful_no_auth_key_full_size(mock_image_data): .mock ) - res = photon_get( - TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER, is_full_size=True - ) + res = photon_get(TEST_MEDIA_INFO, RequestConfig(is_full_size=True)) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -176,11 +185,8 @@ def test_get_successful_no_auth_key_full_size_not_compressed(mock_image_data): ) res = photon_get( - TEST_IMAGE_URL, - TEST_MEDIA_IDENTIFIER, - TEST_MEDIA_PROVIDER, - is_full_size=True, - is_compressed=False, + TEST_MEDIA_INFO, + RequestConfig(is_full_size=True, is_compressed=False), ) assert res.content == MOCK_BODY.encode() @@ -205,12 +211,7 @@ def test_get_successful_no_auth_key_png_only(mock_image_data): .mock ) - res = photon_get( - TEST_IMAGE_URL, - TEST_MEDIA_IDENTIFIER, - TEST_MEDIA_PROVIDER, - accept_header="image/png", - ) + res = photon_get(TEST_MEDIA_INFO, RequestConfig(accept_header="image/png")) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -236,9 +237,11 @@ def test_get_successful_forward_query_params(mock_image_data): .mock ) - url_with_params = f"{TEST_IMAGE_URL}?{params}" + media_info_with_url_params = replace( + TEST_MEDIA_INFO, image_url=f"{TEST_IMAGE_URL}?{params}" + ) - res = photon_get(url_with_params, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + res = photon_get(media_info_with_url_params) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -273,7 +276,7 @@ def test_get_successful_records_response_code(mock_image_data, redis): .mock ) - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + photon_get(TEST_MEDIA_INFO) month = get_monthly_timestamp() assert redis.get(f"thumbnail_response_code:{month}:200") == b"1" assert ( @@ -328,7 +331,7 @@ def test_get_exception_handles_error( redis.set(key, count_start) with pytest.raises(UpstreamThumbnailException): - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + photon_get(TEST_MEDIA_INFO) assert_func = ( capture_exception.assert_called_once @@ -369,7 +372,7 @@ def test_get_http_exception_handles_error( redis.set(key, count_start) with pytest.raises(UpstreamThumbnailException): - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + photon_get(TEST_MEDIA_INFO) assert_func = ( capture_exception.assert_called_once @@ -407,7 +410,9 @@ def test_get_successful_https_image_url_sends_ssl_parameter(mock_image_data): .mock ) - res = photon_get(https_url, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + https_media_info = replace(TEST_MEDIA_INFO, image_url=https_url) + + res = photon_get(https_media_info) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -421,7 +426,7 @@ def test_get_unsuccessful_request_raises_custom_exception(): with pytest.raises( UpstreamThumbnailException, match=r"Failed to render thumbnail." ): - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + photon_get(TEST_MEDIA_INFO) assert mock_get.matched @@ -450,9 +455,14 @@ def test__get_extension_from_url(image_url, expected_ext): def test_photon_get_raises_by_not_allowed_types(image_type): image_url = TEST_IMAGE_URL.replace(".jpg", f".{image_type}") image = ImageFactory.create(url=image_url) + media_info = MediaInfo( + media_identifier=image.identifier, + media_provider=image.provider, + image_url=image_url, + ) with pytest.raises(UnsupportedMediaType): - photon_get(image_url, image.identifier, image.provider) + photon_get(media_info) @pytest.mark.django_db @@ -466,10 +476,15 @@ def test_photon_get_raises_by_not_allowed_types(image_type): def test_photon_get_saves_image_type_to_cache(redis, headers, expected_cache_val): image_url = TEST_IMAGE_URL.replace(".jpg", "") image = ImageFactory.create(url=image_url) + media_info = MediaInfo( + media_identifier=image.identifier, + media_provider=image.provider, + image_url=image_url, + ) with pook.use(): pook.head(image_url, reply=200, response_headers=headers) with pytest.raises(UnsupportedMediaType): - photon_get(image_url, image.identifier, image.provider) + photon_get(media_info) key = f"media:{image.identifier}:thumb_type" assert redis.get(key) == expected_cache_val diff --git a/api/test/unit/views/test_image_views.py b/api/test/unit/views/test_image_views.py index ffe7c0352ed..dc258dfc246 100644 --- a/api/test/unit/views/test_image_views.py +++ b/api/test/unit/views/test_image_views.py @@ -3,10 +3,9 @@ from dataclasses import dataclass from pathlib import Path from test.factory.models.image import ImageFactory -from unittest.mock import ANY, patch - -from django.http import HttpResponse +from unittest.mock import patch +import pook import pytest from PIL import UnidentifiedImageError from requests import Request, Response @@ -35,7 +34,7 @@ def _default_response_factory(req: Request) -> Response: return res -@pytest.fixture(autouse=True) +@pytest.fixture def requests(monkeypatch) -> RequestsFixture: fixture = RequestsFixture([]) @@ -68,18 +67,27 @@ def test_oembed_sends_ua_header(api_client, requests): [(True, "http://iip.smk.dk/thumb.jpg"), (False, "http://iip.smk.dk/image.jpg")], ) def test_thumbnail_uses_upstream_thumb_for_smk( - api_client, smk_has_thumb, expected_thumb_url + api_client, smk_has_thumb, expected_thumb_url, settings ): thumb_url = "http://iip.smk.dk/thumb.jpg" if smk_has_thumb else None image = ImageFactory.create( url="http://iip.smk.dk/image.jpg", thumbnail=thumb_url, ) - with patch("api.views.media_views.MediaViewSet.thumbnail") as thumb_call: - mock_response = HttpResponse("mock_response") - thumb_call.return_value = mock_response - api_client.get(f"/v1/images/{image.identifier}/thumb/") - thumb_call.assert_called_once_with(ANY, image, expected_thumb_url) + + with pook.use(): + mock_get = ( + # Pook interprets a trailing slash on the URL as the path, + # so strip that so the `path` matcher works + pook.get(settings.PHOTON_ENDPOINT[:-1]) + .path(expected_thumb_url.replace("http://", "/")) + .response(200) + ).mock + + response = api_client.get(f"/v1/images/{image.identifier}/thumb/") + + assert response.status_code == 200 + assert mock_get.matched is True @pytest.mark.django_db @@ -89,9 +97,13 @@ def test_watermark_raises_424_for_invalid_image(api_client): "cannot identify image file <_io.BytesIO object at 0xffff86d8fec0>" ) - with patch("PIL.Image.open") as mock_open: - mock_open.side_effect = UnidentifiedImageError(expected_error_message) - res = api_client.get(f"/v1/images/{image.identifier}/watermark/") + with pook.use(): + pook.get(image.url).reply(200) + + with patch("PIL.Image.open") as mock_open: + mock_open.side_effect = UnidentifiedImageError(expected_error_message) + res = api_client.get(f"/v1/images/{image.identifier}/watermark/") + assert res.status_code == 424 assert res.data["detail"] == expected_error_message