Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pull all addons in one DB request #5289

Merged
merged 5 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions api/api/models/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,6 @@ def duration_in_s(self):
def audio_set(self):
return getattr(self, "audioset")

def get_waveform(self) -> list[float]:
"""
Get the waveform if it exists. Return a blank list otherwise.

:return: the waveform, if it exists; empty list otherwise
"""

try:
add_on = AudioAddOn.objects.get(audio_identifier=self.identifier)
return add_on.waveform_peaks or []
except AudioAddOn.DoesNotExist:
return []

def get_or_create_waveform(self):
add_on, _ = AudioAddOn.objects.get_or_create(audio_identifier=self.identifier)

Expand Down
6 changes: 3 additions & 3 deletions api/api/serializers/audio_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_peaks(self, obj) -> list[int]:
if isinstance(obj, Hit):
obj = Audio.objects.get(identifier=obj.identifier)
return obj.get_waveform()
audio_addon = self.context.get("addons", {}).get(obj.identifier)
if audio_addon:
return audio_addon.waveform_peaks

def to_representation(self, instance):
# Get the original representation
Expand Down
2 changes: 2 additions & 0 deletions api/api/views/audio_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from api.docs.audio_docs import thumbnail as thumbnail_docs
from api.models import Audio
from api.models.audio import AudioAddOn
from api.serializers.audio_serializers import (
AudioReportRequestSerializer,
AudioSearchRequestSerializer,
Expand All @@ -38,6 +39,7 @@ class AudioViewSet(MediaViewSet):
"""Viewset for all endpoints pertaining to audio."""

model_class = Audio
addon_model_class = AudioAddOn
media_type = AUDIO_TYPE
query_serializer_class = AudioSearchRequestSerializer
default_index = settings.MEDIA_INDEX_MAPPING[AUDIO_TYPE]
Expand Down
28 changes: 22 additions & 6 deletions api/api/views/media_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from api.controllers import search_controller
from api.controllers.elasticsearch.related import related_media
from api.models import ContentSource
from api.models.base import OpenLedgerModel
from api.models.media import AbstractMedia
from api.serializers import media_serializers
from api.serializers.source_serializers import SourceSerializer
Expand Down Expand Up @@ -51,6 +52,7 @@ class MediaViewSet(AsyncViewSetMixin, AsyncAPIView, ReadOnlyModelViewSet):

# Populate these in the corresponding subclass
model_class: type[AbstractMedia] = None
addon_model_class: type[OpenLedgerModel] = None
media_type: MediaType | None = None
query_serializer_class = None
default_index = None
Expand Down Expand Up @@ -97,7 +99,11 @@ def _get_request_serializer(self, request):
req_serializer.is_valid(raise_exception=True)
return req_serializer

def get_db_results(self, results):
def get_db_results(
self,
results,
include_addons=False,
) -> tuple[list[AbstractMedia], list[OpenLedgerModel]]:
"""
Map ES hits to ORM model instances.

Expand All @@ -107,6 +113,7 @@ def get_db_results(self, results):
which is both unique and indexed, so it's quite performant.

:param results: the list of ES hits
:param include_addons: whether to include add-ons with results
:return: the corresponding list of ORM model instances
"""

Expand All @@ -121,7 +128,12 @@ def get_db_results(self, results):
for result, hit in zip(results, hits):
result.fields_matched = getattr(hit.meta, "highlight", None)

return results
if include_addons and self.addon_model_class:
addons = list(self.addon_model_class.objects.filter(pk__in=identifiers))
else:
addons = []

return (results, addons)

# Standard actions

Expand Down Expand Up @@ -188,9 +200,13 @@ def get_media_results(
except ValueError as e:
raise APIException(getattr(e, "message", str(e)))

serializer_context = search_context | self.get_serializer_context()

results = self.get_db_results(results)
peaks = params.validated_data.get("peaks")
results, addons = self.get_db_results(results, include_addons=peaks)
serializer_context = (
search_context
| self.get_serializer_context()
| {"addons": {addon.audio_identifier: addon for addon in addons}}
)

serializer = self.get_serializer(results, many=True, context=serializer_context)
return self.get_paginated_response(serializer.data)
Expand Down Expand Up @@ -231,7 +247,7 @@ def related(self, request, identifier=None, *_, **__):

serializer_context = self.get_serializer_context()

results = self.get_db_results(results)
results, _ = self.get_db_results(results)

serializer = self.get_serializer(results, many=True, context=serializer_context)
return self.get_paginated_response(serializer.data)
Expand Down
8 changes: 6 additions & 2 deletions api/test/integration/test_dead_link_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def get_empty_cached_statuses(_, image_urls):
_MAKE_HEAD_REQUESTS_MODULE_PATH = "api.utils.check_dead_links._make_head_requests"


def _mock_get_db_results(results, include_addons=False):
return (results, [])


def _patch_make_head_requests():
def _make_head_requests(urls, *args, **kwargs):
responses = []
Expand Down Expand Up @@ -67,7 +71,7 @@ def test_dead_link_filtering(mocked_map, api_client):
with patch(
"api.views.image_views.ImageViewSet.get_db_results"
) as mock_get_db_result:
mock_get_db_result.side_effect = lambda value: value
mock_get_db_result.side_effect = _mock_get_db_results
res_with_dead_links = api_client.get(
path,
query_params | {"filter_dead": False},
Expand Down Expand Up @@ -121,7 +125,7 @@ def test_dead_link_filtering_all_dead_links(
with patch(
"api.views.image_views.ImageViewSet.get_db_results"
) as mock_get_db_result:
mock_get_db_result.side_effect = lambda value: value
mock_get_db_result.side_effect = _mock_get_db_results
with patch_link_validation_dead_for_count(page_size / DEAD_LINK_RATIO):
response = api_client.get(
path,
Expand Down
25 changes: 0 additions & 25 deletions api/test/unit/models/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,3 @@ def test_audio_waveform_caches(generate_peaks_mock, audio_fixture):
audio_fixture.delete()

assert AudioAddOn.objects.count() == 1


@pytest.mark.django_db
@mock.patch("api.models.audio.AudioAddOn.objects.get")
def test_audio_waveform_sent_when_present(get_mock, audio_fixture):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be new tests for the modified functionality of get_db_results with the include_addons parameter?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, I'll add some tests for that asap.

# When ``AudioAddOn.waveform_peaks`` exists, waveform is filled
peaks = [0, 0.25, 0.5, 0.25, 0.1]
get_mock.return_value = mock.Mock(waveform_peaks=peaks)
assert audio_fixture.get_waveform() == peaks


@pytest.mark.django_db
@mock.patch("api.models.audio.AudioAddOn.objects.get")
def test_audio_waveform_blank_when_absent(get_mock, audio_fixture):
# When ``AudioAddOn`` does not exist, waveform is blank
get_mock.side_effect = AudioAddOn.DoesNotExist()
assert audio_fixture.get_waveform() == []


@pytest.mark.django_db
@mock.patch("api.models.audio.AudioAddOn.objects.get")
def test_audio_waveform_blank_when_none(get_mock, audio_fixture):
# When ``AudioAddOn.waveform_peaks`` is None, waveform is blank
get_mock.return_value = mock.Mock(waveform_peaks=None)
assert audio_fixture.get_waveform() == []
Loading