diff --git a/catalog/dags/providers/provider_api_scripts/wikimedia_commons.py b/catalog/dags/providers/provider_api_scripts/wikimedia_commons.py
index da1009524cc..aa6199b907f 100644
--- a/catalog/dags/providers/provider_api_scripts/wikimedia_commons.py
+++ b/catalog/dags/providers/provider_api_scripts/wikimedia_commons.py
@@ -115,6 +115,7 @@
 import lxml.html as html
 
 from common.constants import AUDIO, IMAGE
+from common.extensions import EXTENSIONS
 from common.licenses import LicenseInfo, get_license_info
 from common.loader import provider_details as prov
 from providers.provider_api_scripts.provider_data_ingester import ProviderDataIngester
@@ -215,7 +216,8 @@ def get_next_query_params(self, prev_query_params: dict | None):
             **self.continue_token,
         }
 
-    def get_media_type(self, record):
+    @staticmethod
+    def get_media_type(record):
         """Get the media_type of a parsed Record"""
         return record["media_type"]
 
@@ -318,7 +320,7 @@ def get_record_data(self, record):
         creator, creator_url = self.extract_creator_info(media_info)
         title = self.extract_title(media_info)
         filesize = media_info.get("size", 0)  # in bytes
-        filetype = self.extract_file_type(media_info)
+        filetype = self.extract_file_type(url, valid_media_type)
         meta_data = self.create_meta_data_dict(record)
 
         record_data = {
@@ -532,9 +534,36 @@ def extract_category_info(media_info):
         return categories_list
 
     @staticmethod
-    def extract_file_type(media_info):
-        filetype = media_info.get("url", "").split(".")[-1]
-        return None if filetype == "" else filetype
+    def extract_file_type(url, media_type):
+        """
+        Extract the filetype from the extension in the media url.
+
+        Only the last path segment of the url is inspected, so a dot in the host
+        name or in a parent directory is never mistaken for an extension. An
+        empty or missing url, or one without an extension, yields None.
+
+        In case of images, we check if the filetype is in the list of valid image
+        types, so we can ignore other media types considered as videos (e.g. .ogv).
+        Audio filetypes are returned lowercased without further validation.
+        """
+        image_extensions = EXTENSIONS.get(IMAGE, {})
+        # Look only at the file name: splitting the whole url on "." would
+        # return e.g. "com/audio" for an extensionless "https://example.com/audio".
+        filename = (url or "").rsplit("/", 1)[-1]
+        _, dot, filetype = filename.rpartition(".")
+        if not (dot and filetype):
+            return None
+
+        filetype = filetype.lower()
+        if (
+            media_type == IMAGE and filetype in image_extensions
+        ) or media_type == AUDIO:
+            return filetype
+
+        logger.warning(
+            f"Invalid filetype for `{media_type}` media type: {filetype}"
+        )
+        return None
 
     @staticmethod
     def extract_license_info(media_info) -> LicenseInfo | None:
diff --git a/catalog/tests/dags/providers/provider_api_scripts/test_wikimedia_commons.py b/catalog/tests/dags/providers/provider_api_scripts/test_wikimedia_commons.py
index 32bf43761a0..6ac554fa68d 100644
--- a/catalog/tests/dags/providers/provider_api_scripts/test_wikimedia_commons.py
+++ b/catalog/tests/dags/providers/provider_api_scripts/test_wikimedia_commons.py
@@ -399,6 +399,31 @@ def test_extract_creator_info_handles_link_as_partial_text(wmc):
     assert expect_creator_url == actual_creator_url
 
 
+@pytest.mark.parametrize(
+    "url, media_type, expected",
+    [
+        # Valid images
+        ("https://example.com/image.jpg", "image", "jpg"),
+        ("https://example.com/image.JpeG", "image", "jpeg"),
+        ("https://example.com/image.Png", "image", "png"),
+        ("https://example.com/image.GIF", "image", "gif"),
+        # Invalid (for our sake) images
+        ("https://example.com/image.ogv", "image", None),
+        ("https://example.com/image.xyz", "image", None),
+        # Valid audio
+        ("https://example.com/audio.mp3", "audio", "mp3"),
+        ("https://example.com/audio.ogg", "audio", "ogg"),
+        ("https://example.com/audio.WAV", "audio", "wav"),
+        # No usable extension in the file name
+        ("https://example.com/audio", "audio", None),
+        ("https://example.com/image", "image", None),
+        ("", "audio", None),
+    ],
+)
+def test_extract_file_type(wmc, url, media_type, expected):
+    assert wmc.extract_file_type(url, media_type) == expected
+
+
 def test_extract_license_info_finds_license_url(wmc):
     image_info = _get_resource_json("image_info_from_example_data.json")
     expect_license_url = "https://creativecommons.org/licenses/by-sa/4.0/"