Skip to content

Commit

Permalink
Add ADRF and make the thumbnail view async (#3020)
Browse files Browse the repository at this point in the history
* Add ADRF and make the thumbnail view async

This does not convert the `image_proxy.get` to async to avoid making this initial PR more complicated than it needs to be, and to avoid being blocked on the aiohttp client session sharing which we will want to have in before we convert the image proxy to use aiohttp.

* Fix bad merge

* Fix bad merge

* Fix docstrings on api docs

* Remove unnecessary thread sensitivity

* Fix bad subclassing of thumbnails endpoints

* Use simpler/clearer names for image proxy config objects
  • Loading branch information
sarayourfriend authored Nov 22, 2023
1 parent a840aa6 commit b91cdc3
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 109 deletions.
5 changes: 3 additions & 2 deletions api/Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pytest-sugar = "~=0.9"
pook = {ref = "master", git = "git+https://github.com/h2non/pook.git"}

[packages]
adrf = "~=0.1.2"
aiohttp = "~=3.8"
aws-requests-auth = "~=0.4"
deepdiff = "~=6.4"
Expand All @@ -26,6 +27,7 @@ django-cors-headers = "~=4.2"
django-log-request-id = "~=2.0"
django-oauth-toolkit = "~=2.3"
django-redis = "~=5.4"
django-split-settings = "*"
django-tqdm = "~=1.3"
django-uuslug = "~=2.0"
djangorestframework = "~=3.14"
Expand All @@ -35,11 +37,10 @@ elasticsearch-dsl = "~=8.9"
future = "~=0.18"
limit = "~=0.2"
Pillow = "~=10.1.0"
psycopg = "~=3.1"
python-decouple = "~=3.8"
python-xmp-toolkit = "~=2.0"
sentry-sdk = "~=1.30"
django-split-settings = "*"
psycopg = "~=3.1"
uvicorn = {extras = ["standard"], version = "~=0.23"}

[requires]
Expand Down
17 changes: 16 additions & 1 deletion api/Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 31 additions & 14 deletions api/api/utils/image_proxy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from dataclasses import dataclass
from typing import Literal
from urllib.parse import urlparse

Expand Down Expand Up @@ -33,13 +34,26 @@
THUMBNAIL_STRATEGY = Literal["photon_proxy", "original"]


@dataclass
class MediaInfo:
media_provider: str
media_identifier: str
image_url: str


@dataclass
class RequestConfig:
accept_header: str = "image/*"
is_full_size: bool = False
is_compressed: bool = True


def get_request_params_for_extension(
ext: str,
headers: dict[str, str],
image_url: str,
parsed_image_url: urlparse,
is_full_size: bool,
is_compressed: bool,
request_config: RequestConfig,
) -> tuple[str, dict[str, str], dict[str, str]]:
"""
Get the request params (url, params, headers) for the thumbnail proxy.
Expand All @@ -49,7 +63,10 @@ def get_request_params_for_extension(
"""
if ext in PHOTON_TYPES:
return get_photon_request_params(
parsed_image_url, is_full_size, is_compressed, headers
parsed_image_url,
request_config.is_full_size,
request_config.is_compressed,
headers,
)
elif ext in ORIGINAL_TYPES:
return image_url, {}, headers
Expand All @@ -59,24 +76,23 @@ def get_request_params_for_extension(


def get(
image_url: str,
media_identifier: str,
media_provider: str,
accept_header: str = "image/*",
is_full_size: bool = False,
is_compressed: bool = True,
media_info: MediaInfo,
request_config: RequestConfig = RequestConfig(),
) -> HttpResponse:
"""
Proxy an image through Photon if its file type is supported, else return the
original image if the file type is SVG. Otherwise, raise an exception.
"""
image_url = media_info.image_url
media_identifier = media_info.media_identifier

logger = parent_logger.getChild("get")
tallies = django_redis.get_redis_connection("tallies")
month = get_monthly_timestamp()

image_extension = get_image_extension(image_url, media_identifier)

headers = {"Accept": accept_header} | HEADERS
headers = {"Accept": request_config.accept_header} | HEADERS

parsed_image_url = urlparse(image_url)
domain = parsed_image_url.netloc
Expand All @@ -86,8 +102,7 @@ def get(
headers,
image_url,
parsed_image_url,
is_full_size,
is_compressed,
request_config,
)

try:
Expand All @@ -103,7 +118,7 @@ def get(
f"{month}:{upstream_response.status_code}"
)
tallies.incr(
f"thumbnail_response_code_by_provider:{media_provider}:"
f"thumbnail_response_code_by_provider:{media_info.media_provider}:"
f"{month}:{upstream_response.status_code}"
)
upstream_response.raise_for_status()
Expand Down Expand Up @@ -133,7 +148,9 @@ def get(
f"thumbnail_http_error:{domain}:{month}:{code}:{exc.response.text}"
)
logger.warning(
f"Failed to render thumbnail {upstream_url=} {code=} {media_provider=}"
f"Failed to render thumbnail "
f"{upstream_url=} {code=} "
f"{media_info.media_provider=}"
)
raise UpstreamThumbnailException(f"Failed to render thumbnail. {exc}")

Expand Down
38 changes: 19 additions & 19 deletions api/api/views/audio_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
source_collection,
stats,
tag_collection,
thumbnail,
waveform,
)
from api.docs.audio_docs import thumbnail as thumbnail_docs
from api.docs.audio_docs import waveform
from api.models import Audio
from api.serializers.audio_serializers import (
AudioCollectionRequestSerializer,
Expand All @@ -26,7 +26,7 @@
AudioSerializer,
AudioWaveformSerializer,
)
from api.serializers.media_serializers import MediaThumbnailRequestSerializer
from api.utils import image_proxy
from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle
from api.views.media_views import MediaViewSet

Expand Down Expand Up @@ -80,21 +80,8 @@ def source_collection(self, request, source):
def tag_collection(self, request, tag, *_, **__):
return super().tag_collection(request, tag, *_, **__)

@thumbnail
@action(
detail=True,
url_path="thumb",
url_name="thumb",
serializer_class=MediaThumbnailRequestSerializer,
throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle],
)
def thumbnail(self, request, *_, **__):
"""
Retrieve the scaled down and compressed thumbnail of the artwork of an
audio track or its audio set.
"""

audio = self.get_object()
async def get_image_proxy_media_info(self) -> image_proxy.MediaInfo:
audio = await self.aget_object()

image_url = None
if audio_thumbnail := audio.thumbnail:
Expand All @@ -104,7 +91,20 @@ def thumbnail(self, request, *_, **__):
if not image_url:
raise NotFound("Could not find artwork.")

return super().thumbnail(request, audio, image_url)
return image_proxy.MediaInfo(
media_identifier=audio.identifier,
media_provider=audio.provider,
image_url=image_url,
)

@thumbnail_docs
@MediaViewSet.thumbnail_action
async def thumbnail(self, *args, **kwargs):
"""
Retrieve the scaled down and compressed thumbnail of the artwork of an
audio track or its audio set.
"""
return await super().thumbnail(*args, **kwargs)

@waveform
@action(
Expand Down
35 changes: 16 additions & 19 deletions api/api/views/image_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
source_collection,
stats,
tag_collection,
thumbnail,
)
from api.docs.image_docs import thumbnail as thumbnail_docs
from api.docs.image_docs import watermark as watermark_doc
from api.models import Image
from api.serializers.image_serializers import (
Expand All @@ -34,11 +34,8 @@
OembedSerializer,
WatermarkRequestSerializer,
)
from api.serializers.media_serializers import (
MediaThumbnailRequestSerializer,
PaginatedRequestSerializer,
)
from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle
from api.serializers.media_serializers import PaginatedRequestSerializer
from api.utils import image_proxy
from api.utils.watermark import watermark
from api.views.media_views import MediaViewSet

Expand Down Expand Up @@ -130,25 +127,25 @@ def oembed(self, request, *_, **__):
serializer = self.get_serializer(image, context=context)
return Response(data=serializer.data)

@thumbnail
@action(
detail=True,
url_path="thumb",
url_name="thumb",
serializer_class=MediaThumbnailRequestSerializer,
throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle],
)
def thumbnail(self, request, *_, **__):
"""Retrieve the scaled down and compressed thumbnail of the image."""

image = self.get_object()
async def get_image_proxy_media_info(self) -> image_proxy.MediaInfo:
image = await self.aget_object()
image_url = image.url
# Hotfix to use thumbnails for SMK images
# TODO: Remove when small thumbnail issues are resolved
if "iip.smk.dk" in image_url and image.thumbnail:
image_url = image.thumbnail

return super().thumbnail(request, image, image_url)
return image_proxy.MediaInfo(
media_identifier=image.identifier,
media_provider=image.provider,
image_url=image_url,
)

@thumbnail_docs
@MediaViewSet.thumbnail_action
async def thumbnail(self, *args, **kwargs):
"""Retrieve the scaled down and compressed thumbnail of the image."""
return await super().thumbnail(*args, **kwargs)

@watermark_doc
@action(detail=True, url_path="watermark", url_name="watermark")
Expand Down
43 changes: 35 additions & 8 deletions api/api/views/media_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from rest_framework.response import Response
from rest_framework.viewsets import ReadOnlyModelViewSet

from adrf.views import APIView as AsyncAPIView
from adrf.viewsets import ViewSetMixin as AsyncViewSetMixin
from asgiref.sync import sync_to_async

from api.constants.media_types import MediaType
from api.constants.search import SearchStrategy
from api.controllers import search_controller
Expand All @@ -18,6 +22,7 @@
from api.utils import image_proxy
from api.utils.pagination import StandardPagination
from api.utils.search_context import SearchContext
from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle


logger = logging.getLogger(__name__)
Expand All @@ -35,7 +40,12 @@ class InvalidSource(APIException):
default_code = "invalid_source"


class MediaViewSet(ReadOnlyModelViewSet):
image_proxy_aget = sync_to_async(image_proxy.get)


class MediaViewSet(AsyncViewSetMixin, AsyncAPIView, ReadOnlyModelViewSet):
view_is_async = True

lookup_field = "identifier"
# TODO: https://github.com/encode/django-rest-framework/pull/6789
lookup_value_regex = (
Expand Down Expand Up @@ -79,6 +89,8 @@ def get_queryset(self):
).values_list("provider_identifier")
)

aget_object = sync_to_async(ReadOnlyModelViewSet.get_object)

def get_serializer_context(self):
context = super().get_serializer_context()
req_serializer = self._get_request_serializer(self.request)
Expand Down Expand Up @@ -265,16 +277,31 @@ def report(self, request, identifier):

return Response(data=serializer.data, status=status.HTTP_201_CREATED)

def thumbnail(self, request, media_obj, image_url):
async def get_image_proxy_media_info(self) -> image_proxy.MediaInfo:
raise NotImplementedError(
"Subclasses must implement `get_image_proxy_media_info`"
)

thumbnail_action = action(
detail=True,
url_path="thumb",
url_name="thumb",
serializer_class=media_serializers.MediaThumbnailRequestSerializer,
throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle],
)

async def thumbnail(self, request, *_, **__):
serializer = self.get_serializer(data=request.query_params)
serializer.is_valid(raise_exception=True)

return image_proxy.get(
image_url,
media_obj.identifier,
media_obj.provider,
accept_header=request.headers.get("Accept", "image/*"),
**serializer.validated_data,
media_info = await self.get_image_proxy_media_info()

return await image_proxy_aget(
media_info,
request_config=image_proxy.RequestConfig(
accept_header=request.headers.get("Accept", "image/*"),
**serializer.validated_data,
),
)

# Helper functions
Expand Down
Loading

0 comments on commit b91cdc3

Please sign in to comment.