diff --git a/.flake8 b/.flake8 index 851862bd..7f375890 100644 --- a/.flake8 +++ b/.flake8 @@ -6,5 +6,4 @@ per-file-ignores = docs/_howto/scripts/*:E402 tests/**/test_*.py:F811 max-line-length = 120 -# TODO: reduce this to 10 -max-complexity = 17 +max-complexity = 10 diff --git a/musify/core/printer.py b/musify/core/printer.py index 92e5cc05..f74044eb 100644 --- a/musify/core/printer.py +++ b/musify/core/printer.py @@ -106,30 +106,21 @@ def _to_str(cls, attributes: Mapping[str, Any], indent: int = 2, increment: int elif isinstance(attr_val, (datetime, date)): attr_val_str = str(attr_val) elif isinstance(attr_val, (list, tuple, set)) and len(attr_val) > 0: - pp_repr = "[{}]" - if isinstance(attr_val, set): - pp_repr = "{{}}" - elif isinstance(attr_val, tuple): - pp_repr = "({})" - - attr_val_str = str(attr_val) - - if len(attr_val_str) > max_val_width: - pp_repr = pp_repr.format("\n" + " " * indent + "{}\n" + " " * indent_prev) - attr_val_pp = [] - for val in attr_val: - if isinstance(val, PrettyPrinter): - attr_val_pp.append(val.__str__(indent=indent + increment, increment=increment)) - else: - attr_val_pp.append(str(val)) - attr_val_str = pp_repr.format((",\n" + " " * indent).join(attr_val_pp)) + attr_val_str = cls._get_attribute_value_from_collection( + attr_val=attr_val, + max_val_width=max_val_width, + indent=indent, + indent_prev=indent_prev, + increment=increment + ) elif isinstance(attr_val, Mapping) and len(attr_val) > 0: - attr_val_pp = cls._to_str(attr_val, indent=indent, increment=increment) - attr_val_str = "{" + ", ".join(attr_val_pp) + "}" - - if len(attr_val_str) > max_val_width: - pp_repr = "\n" + " " * indent + "{}\n" + " " * indent_prev - attr_val_str = "{" + pp_repr.format((",\n" + " " * indent).join(attr_val_pp)) + "}" + attr_val_str = cls._get_attribute_value_from_mapping( + attr_val=attr_val, + max_val_width=max_val_width, + indent=indent, + indent_prev=indent_prev, + increment=increment + ) else: attr_val_str = repr(attr_val) @@ -141,6 +132,43 @@ def _to_str(cls, attributes: Mapping[str, Any], indent: int = 2, increment: int return attributes_repr + @staticmethod + def _get_attribute_value_from_collection( + attr_val: list | tuple | set, max_val_width: int, indent: int, indent_prev: int, increment: int + ) -> str: + pp_repr = "[{}]" + if isinstance(attr_val, set): + pp_repr = "{{}}" + elif isinstance(attr_val, tuple): + pp_repr = "({})" + + attr_val_str = str(attr_val) + + if len(attr_val_str) > max_val_width: + pp_repr = pp_repr.format("\n" + " " * indent + "{}\n" + " " * indent_prev) + attr_val_pp = [] + for val in attr_val: + if isinstance(val, PrettyPrinter): + attr_val_pp.append(val.__str__(indent=indent + increment, increment=increment)) + else: + attr_val_pp.append(str(val)) + attr_val_str = pp_repr.format((",\n" + " " * indent).join(attr_val_pp)) + + return attr_val_str + + @classmethod + def _get_attribute_value_from_mapping( + cls, attr_val: Mapping, max_val_width: int, indent: int, indent_prev: int, increment: int + ) -> str: + attr_val_pp = cls._to_str(attr_val, indent=indent, increment=increment) + attr_val_str = "{" + ", ".join(attr_val_pp) + "}" + + if len(attr_val_str) > max_val_width: + pp_repr = "\n" + " " * indent + "{}\n" + " " * indent_prev + attr_val_str = "{" + pp_repr.format((",\n" + " " * indent).join(attr_val_pp)) + "}" + + return attr_val_str + def __repr__(self): return f"{self.__class__.__name__}({repr(self.as_dict())})" diff --git a/musify/libraries/core/collection.py b/musify/libraries/core/collection.py index 648a117c..dd9e62de 100644 --- a/musify/libraries/core/collection.py +++ b/musify/libraries/core/collection.py @@ -297,6 +297,22 @@ def __getitem__( if isinstance(__key, int) or isinstance(__key, slice): # simply index the list or items return self.items[__key] + getters = self.__get_item_getters(__key) + if not getters: + raise MusifyKeyError(f"Unrecognised key type | {__key=} | type={type(__key).__name__}") + + caught_exceptions = [] + for getter in getters: + try: + return getter.get_item(self) + except (MusifyAttributeError, MusifyKeyError) as ex: + caught_exceptions.append(ex) + + raise MusifyKeyError( + f"Key is invalid. The following errors were thrown: {[str(ex) for ex in caught_exceptions]}" + ) + + def __get_item_getters(self, __key: str) -> list[ItemGetterStrategy]: getters = [] if isinstance(__key, File): getters.append(PathGetter(__key.path)) @@ -316,19 +332,7 @@ def __getitem__( NameGetter(__key), ]) - if not getters: - raise MusifyKeyError(f"Unrecognised key type | {__key=} | type={type(__key).__name__}") - - caught_exceptions = [] - for getter in getters: - try: - return getter.get_item(self) - except (MusifyAttributeError, MusifyKeyError) as ex: - caught_exceptions.append(ex) - - raise MusifyKeyError( - f"Key is invalid. The following errors were thrown: {[str(ex) for ex in caught_exceptions]}" - ) + return getters def __setitem__(self, __key: str | int | T, __value: T): """Replace the item at a given ``__key`` with the given ``__value``.""" diff --git a/musify/libraries/remote/core/processors/check.py b/musify/libraries/remote/core/processors/check.py index 6a178a81..2c46a353 100644 --- a/musify/libraries/remote/core/processors/check.py +++ b/musify/libraries/remote/core/processors/check.py @@ -408,53 +408,58 @@ def _match_to_input(self, name: str) -> None: print("\n" + help_text) for item in self._remaining.copy(): - while item in self._remaining: # while item not matched or skipped + while current_input is not None and item in self._remaining: # while item not matched or skipped self._log_padded([name, f"{len(self._remaining):>6} remaining items"]) if 'a' not in current_input: current_input = self._get_user_input(align_string(item.name, max_width=max_width)) if current_input.casefold() == 'h': # print help print("\n" + help_text) + else: + current_input = self._match_item_to_input(name=name, item=item, current_input=current_input) - elif current_input.casefold() == 's' or current_input.casefold() == 'q': # quit/skip - self._log_padded([name, "Skipping all loops"], pad="<") - self._quit = current_input.casefold() == 'q' or self._quit - self._skip = current_input.casefold() == 's' or self._skip - self._remaining.clear() - return - - elif current_input.casefold().replace('a', '') == 'u': # mark item as unavailable - self._log_padded([name, "Marking as unavailable"], pad="<") - item.uri = self.api.wrangler.unavailable_uri_dummy - self._remaining.remove(item) - - elif current_input.casefold().replace('a', '') == 'n': # leave item without URI and unprocessed - self._log_padded([name, "Skipping"], pad="<") - item.uri = None - self._remaining.remove(item) - - elif current_input.casefold() == 'r': # return to former 'while' loop - self._log_padded([name, "Refreshing playlist metadata and restarting loop"]) - return - - elif current_input.casefold() == 'p' and hasattr(item, "path"): # print item path - print(f"\33[96m{item.path}\33[0m") - - elif self.api.wrangler.validate_id_type(current_input): # update URI and add item to switched list - uri = self.api.wrangler.convert( - current_input, kind=RemoteObjectType.TRACK, type_out=RemoteIDType.URI - ) + if current_input is None or not self._remaining: + break + + def _match_item_to_input(self, name: str, item: MusifyItemSettable, current_input: str) -> str | None: + if current_input.casefold() == 's' or current_input.casefold() == 'q': # quit/skip + self._log_padded([name, "Skipping all loops"], pad="<") + self._quit = current_input.casefold() == 'q' or self._quit + self._skip = current_input.casefold() == 's' or self._skip + self._remaining.clear() + return - self._log_padded([name, f"Updating URI: {item.uri} -> {uri}"], pad="<") - item.uri = uri + elif current_input.casefold().replace('a', '') == 'u': # mark item as unavailable + self._log_padded([name, "Marking as unavailable"], pad="<") + item.uri = self.api.wrangler.unavailable_uri_dummy + self._remaining.remove(item) - self._switched.append(item) - self._remaining.remove(item) - current_input = "" + elif current_input.casefold().replace('a', '') == 'n': # leave item without URI and unprocessed + self._log_padded([name, "Skipping"], pad="<") + item.uri = None + self._remaining.remove(item) - elif current_input: # invalid input - self.logger.warning("Input not recognised.") - current_input = "" + elif current_input.casefold() == 'r': # return to former 'while' loop + self._log_padded([name, "Refreshing playlist metadata and restarting loop"]) + return - if not self._remaining: - break + elif current_input.casefold() == 'p' and hasattr(item, "path"): # print item path + print(f"\33[96m{item.path}\33[0m") + + elif self.api.wrangler.validate_id_type(current_input): # update URI and add item to switched list + uri = self.api.wrangler.convert( + current_input, kind=RemoteObjectType.TRACK, type_out=RemoteIDType.URI + ) + + self._log_padded([name, f"Updating URI: {item.uri} -> {uri}"], pad="<") + item.uri = uri + + self._switched.append(item) + self._remaining.remove(item) + current_input = "" + + elif current_input: # invalid input + self.logger.warning("Input not recognised.") + current_input = "" + + return current_input diff --git a/musify/libraries/remote/spotify/processors.py b/musify/libraries/remote/spotify/processors.py index ac9f631a..d3f0eb6a 100644 --- a/musify/libraries/remote/spotify/processors.py +++ b/musify/libraries/remote/spotify/processors.py @@ -68,18 +68,9 @@ def _get_item_type( cls, value: str | Mapping[str, Any] | RemoteResponse, kind: RemoteObjectType | None = None ) -> RemoteObjectType | None: if isinstance(value, RemoteResponse): - response_kind = cls._get_item_type(value.response) - if value.kind != response_kind: - raise RemoteObjectTypeError( - f"RemoteResponse kind != actual response kind: {value.kind} != {response_kind}" - ) - return value.kind + return cls._get_item_type_from_response(value) if isinstance(value, Mapping): - if value.get("is_local", False): - raise RemoteObjectTypeError("Cannot process local items") - if "type" not in value: - raise RemoteObjectTypeError(f"Given map does not contain a 'type' key: {value}") - return RemoteObjectType.from_name(value["type"].casefold().rstrip('s'))[0] + return cls._get_item_type_from_mapping(value) value = value.strip() uri_check = value.split(':') @@ -99,6 +90,23 @@ def _get_item_type( return kind raise RemoteObjectTypeError(f"Could not determine item type of given value: {value}") + @classmethod + def _get_item_type_from_response(cls, value: RemoteResponse) -> RemoteObjectType: + response_kind = cls._get_item_type_from_mapping(value.response) + if value.kind != response_kind: + raise RemoteObjectTypeError( + f"RemoteResponse kind != actual response kind: {value.kind} != {response_kind}" + ) + return value.kind + + @classmethod + def _get_item_type_from_mapping(cls, value: Mapping[str, Any]) -> RemoteObjectType: + if value.get("is_local", False): + raise RemoteObjectTypeError("Cannot process local items") + if "type" not in value: + raise RemoteObjectTypeError(f"Given map does not contain a 'type' key: {value}") + return RemoteObjectType.from_name(value["type"].casefold().rstrip('s'))[0] + @classmethod def convert( cls, @@ -114,34 +122,14 @@ def convert( value = value.strip() - if type_in == RemoteIDType.URL_EXT or type_in == RemoteIDType.URL: # open/API URL - url_path = urlparse(value).path.split("/") - for chunk in url_path: - try: - kind = RemoteObjectType.from_name(chunk.rstrip('s'))[0] - break - except MusifyEnumError: - continue - - if kind == RemoteObjectType.USER: - name = kind.name.lower() - try: - id_ = url_path[url_path.index(name) + 1] - except ValueError: - id_ = url_path[url_path.index(name + "s") + 1] - else: - id_ = next(p for p in url_path if len(p) == RemoteIDType.ID.value) - + if type_in == RemoteIDType.URL_EXT or type_in == RemoteIDType.URL: + kind, id_ = cls._get_id_from_url(value=value, kind=kind) elif type_in == RemoteIDType.URI: - uri_split = value.split(':') - kind = RemoteObjectType.from_name(uri_split[1])[0] - id_ = uri_split[2] - + kind, id_ = cls._get_id_from_uri(value=value) elif type_in == RemoteIDType.ID: if kind is None: raise RemoteIDTypeError("Input value is an ID and no defined 'kind' has been given.", RemoteIDType.ID) id_ = value - else: raise RemoteIDTypeError(f"Could not determine item type: {value}") @@ -156,6 +144,34 @@ def convert( else: return id_ + @classmethod + def _get_id_from_url(cls, value: str, kind: RemoteObjectType | None = None) -> tuple[RemoteObjectType, str]: + url_path = urlparse(value).path.split("/") + for chunk in url_path: + try: + kind = RemoteObjectType.from_name(chunk.rstrip('s'))[0] + break + except MusifyEnumError: + continue + + if kind == RemoteObjectType.USER: + name = kind.name.lower() + try: + id_ = url_path[url_path.index(name) + 1] + except ValueError: + id_ = url_path[url_path.index(name + "s") + 1] + else: + id_ = next(p for p in url_path if len(p) == RemoteIDType.ID.value) + + return kind, id_ + + @classmethod + def _get_id_from_uri(cls, value: str) -> tuple[RemoteObjectType, str]: + uri_split = value.split(':') + kind = RemoteObjectType.from_name(uri_split[1])[0] + id_ = uri_split[2] + return kind, id_ + @classmethod def extract_ids(cls, values: APIInputValue, kind: RemoteObjectType | None = None) -> list[str]: def extract_id(value: str | Mapping[str, Any] | RemoteResponse) -> str: diff --git a/musify/processors/limit.py b/musify/processors/limit.py index 714d8fe4..6d3e36db 100644 --- a/musify/processors/limit.py +++ b/musify/processors/limit.py @@ -120,28 +120,42 @@ def limit[T: MusifyItem](self, items: list[T], ignore: Collection[T] = ()) -> No items_limit = [t for t in items] items.clear() - if self.kind == LimitType.ITEMS: # limit on items + if self.kind == LimitType.ITEMS: items.extend(items_limit[:self.limit_max]) - elif self.kind == LimitType.ALBUMS: # limit on albums - seen_albums = [] - for item in items_limit: - if not isinstance(item, Track): - ItemLimiterError("In order to limit on Album, all items must be of type 'Track'") - - if len(seen_albums) < self.limit_max and item.album not in seen_albums: - # album limit not yet reached - seen_albums.append(item.album) - if item.album in seen_albums: - items.append(item) - else: # limit on duration or size - count = 0 - for item in items_limit: - value = self._convert(item) - if count + value <= self.limit_max * self.allowance: # limit not yet reached - items.append(item) - count += value - if count > self.limit_max: # limit reached - break + elif self.kind == LimitType.ALBUMS: + items.extend(self._limit_on_albums(items_limit)) + else: + items.extend(self._limit_on_numeric(items_limit)) + + def _limit_on_albums[T: MusifyItem](self, items: list[T]) -> list[T]: + seen_albums = [] + result = [] + + for item in items: + if not isinstance(item, Track): + ItemLimiterError("In order to limit on Album, all items must be of type 'Track'") + + if len(seen_albums) < self.limit_max and item.album not in seen_albums: + # album limit not yet reached + seen_albums.append(item.album) + if item.album in seen_albums: + result.append(item) + + return result + + def _limit_on_numeric[T: MusifyItem](self, items: list[T]) -> list[T]: + count = 0 + result = [] + + for item in items: + value = self._convert(item) + if count + value <= self.limit_max * self.allowance: # limit not yet reached + result.append(item) + count += value + if count > self.limit_max: # limit reached + break + + return result def _convert(self, item: MusifyItem) -> float: """ diff --git a/musify/report.py b/musify/report.py index d744da2a..a5723839 100644 --- a/musify/report.py +++ b/musify/report.py @@ -98,26 +98,11 @@ def report_missing_tags( logger: MusifyLogger = logging.getLogger(__name__) logger.debug("Report missing tags: START") - tags = to_collection(tags, set) - tag_order = [field.name.lower() for field in ALL_FIELDS] - # noinspection PyTypeChecker - tag_names = set(TagField.__tags__) if Fields.ALL in tags else TagField.to_tags(tags) - tag_names: list[str] = list(sorted(tag_names, key=lambda x: tag_order.index(x))) - if isinstance(collections, LocalLibrary): collections = collections.albums - items_total = sum(len(collection) for collection in collections) - logger.info( - f"\33[1;95m ->\33[1;97m " - f"Checking {items_total} items for {'all' if match_all else 'any'} missing tags: \n" - f" \33[90m{', '.join(tag_names)}\33[0m" - ) - - if Fields.URI in tags or Fields.ALL in tags: - tag_names[tag_names.index(Fields.URI.name.lower())] = "has_uri" - if Fields.IMAGES in tags or Fields.ALL in tags: - tag_names[tag_names.index(Fields.IMAGES.name.lower())] = "has_image" + item_total = sum(len(collection) for collection in collections) + tag_names = _get_tag_names(logger=logger, tags=tags, item_total=item_total, match_all=match_all) missing: dict[str, dict[MusifyItem, tuple[str, ...]]] = {} for collection in collections: @@ -138,26 +123,53 @@ def report_missing_tags( logger.debug("Report missing tags: DONE\n") return missing + _log_missing_tags(logger=logger, missing=missing) + + missing_tags_all = {tag for items in missing.values() for tags in items.values() for tag in tags} + tag_order = [field.name.lower() for field in ALL_FIELDS] + logger.info( + f" \33[94mFound {len({item for items in missing.values() for item in items})} items with " + f"{'all' if match_all else 'any'} missing tags\33[0m: \n" + f" \33[90m{', '.join(sorted(missing_tags_all, key=lambda x: tag_order.index(x)))}\33[0m" + ) + logger.print() + logger.debug("Report missing tags: DONE\n") + return missing + + +def _get_tag_names(logger: MusifyLogger, tags: UnitIterable[TagField], item_total: int, match_all: bool) -> list[str]: + tags = to_collection(tags, set) + tag_order = [field.name.lower() for field in ALL_FIELDS] + # noinspection PyTypeChecker + tag_names_set = set(TagField.__tags__) if Fields.ALL in tags else TagField.to_tags(tags) + tag_names: list[str] = list(sorted(tag_names_set, key=lambda x: tag_order.index(x))) + + logger.info( + f"\33[1;95m ->\33[1;97m " + f"Checking {item_total} items for {'all' if match_all else 'any'} missing tags: \n" + f" \33[90m{', '.join(tag_names)}\33[0m" + ) + + # switch out tag names with expected attribute names to check on + if Fields.URI in tags or Fields.ALL in tags: + tag_names[tag_names.index(Fields.URI.name.lower())] = "has_uri" + if Fields.IMAGES in tags or Fields.ALL in tags: + tag_names[tag_names.index(Fields.IMAGES.name.lower())] = "has_image" + + return tag_names + + +def _log_missing_tags(logger: MusifyLogger, missing: dict[str, dict[MusifyItem, tuple[str, ...]]]) -> None: all_keys = {item.name for items in missing.values() for item in items} max_width = get_max_width(all_keys) - # log the report logger.print(REPORT) logger.report("\33[1;94mFound the following missing items by collection: \33[0m") logger.print(REPORT) + for name, result in missing.items(): logger.report(f"\33[1;91m -> {name} \33[0m") for item, tags in result.items(): n = align_string(item.name, max_width=max_width) logger.report(f"\33[96m{n} \33[0m| \33[93m{', '.join(tags)} \33[0m") logger.print(REPORT) - - missing_tags_all = {tag for items in missing.values() for tags in items.values() for tag in tags} - logger.info( - f" \33[94mFound {len(all_keys)} items with " - f"{'all' if match_all else 'any'} missing tags\33[0m: \n" - f" \33[90m{', '.join(sorted(missing_tags_all, key=lambda x: tag_order.index(x)))}\33[0m" - ) - logger.print() - logger.debug("Report missing tags: DONE\n") - return missing diff --git a/tests/libraries/local/track/test_track.py b/tests/libraries/local/track/test_track.py index 3ff9bb33..3905e282 100644 --- a/tests/libraries/local/track/test_track.py +++ b/tests/libraries/local/track/test_track.py @@ -311,43 +311,44 @@ def test_merge_dunder_methods(self, track: LocalTrack, item_modified: Track): class TestLocalTrackWriter: @staticmethod - def assert_track_tags_equal(actual: LocalTrack, expected: LocalTrack, check_tag_exists: bool = False): - """ - Assert the tags of the givens tracks equal. - ``check_tag_exists`` checks that a mapping for that tag exists before comparing, skipping any that don't - """ - if not check_tag_exists or actual.tag_map.title: - assert actual.title == expected.title, "title" - if not check_tag_exists or actual.tag_map.artist: - assert actual.artist == expected.artist, "artist" - if not check_tag_exists or actual.tag_map.album: - assert actual.album == expected.album, "album" - if not check_tag_exists or actual.tag_map.album_artist: - assert actual.album_artist == expected.album_artist, "album_artist" - if not check_tag_exists or actual.tag_map.track_number: - assert actual.track_number == expected.track_number, "track_number" - if not check_tag_exists or actual.tag_map.track_total: - assert actual.track_total == expected.track_total, "track_total" - if not check_tag_exists or actual.tag_map.genres: - assert actual.genres == expected.genres, "genres" - if not check_tag_exists or actual.tag_map.date: - assert actual.date == expected.date, "date" - if not check_tag_exists or actual.tag_map.year: - assert actual.year == expected.year, "year" - if not check_tag_exists or actual.tag_map.month: - assert actual.month == expected.month, "month" - if not check_tag_exists or actual.tag_map.day: - assert actual.day == expected.day, "day" - if not check_tag_exists or actual.tag_map.bpm: - assert actual.bpm == expected.bpm, "bpm" - if not check_tag_exists or actual.tag_map.key: - assert actual.key == expected.key, "key" - if not check_tag_exists or actual.tag_map.disc_number: - assert actual.disc_number == expected.disc_number, "disc_number" - if not check_tag_exists or actual.tag_map.disc_total: - assert actual.disc_total == expected.disc_total, "disc_total" - if not check_tag_exists or actual.tag_map.compilation: - assert actual.compilation == expected.compilation, "compilation" + def assert_track_tags_equal(actual: LocalTrack, expected: LocalTrack): + """Assert the tags of the givens tracks equal.""" + assert actual.title == expected.title, "title" + assert actual.artist == expected.artist, "artist" + assert actual.album == expected.album, "album" + assert actual.album_artist == expected.album_artist, "album_artist" + assert actual.track_number == expected.track_number, "track_number" + assert actual.track_total == expected.track_total, "track_total" + assert actual.genres == expected.genres, "genres" + assert actual.date == expected.date, "date" + assert actual.year == expected.year, "year" + assert actual.month == expected.month, "month" + assert actual.day == expected.day, "day" + assert actual.bpm == expected.bpm, "bpm" + assert actual.key == expected.key, "key" + assert actual.disc_number == expected.disc_number, "disc_number" + assert actual.disc_total == expected.disc_total, "disc_total" + assert actual.compilation == expected.compilation, "compilation" + + @staticmethod + def assert_track_tags_equal_on_existing(actual: LocalTrack, expected: LocalTrack): + """Assert the tags of the givens tracks equal only when a mapping for that tag exists.""" + assert not actual.tag_map.title or actual.title == expected.title, "title" + assert not actual.tag_map.artist or actual.artist == expected.artist, "artist" + assert not actual.tag_map.album or actual.album == expected.album, "album" + assert not actual.tag_map.album_artist or actual.album_artist == expected.album_artist, "album_artist" + assert not actual.tag_map.track_number or actual.track_number == expected.track_number, "track_number" + assert not actual.tag_map.track_total or actual.track_total == expected.track_total, "track_total" + assert not actual.tag_map.genres or actual.genres == expected.genres, "genres" + assert not actual.tag_map.date or actual.date == expected.date, "date" + assert not actual.tag_map.year or actual.year == expected.year, "year" + assert not actual.tag_map.month or actual.month == expected.month, "month" + assert not actual.tag_map.day or actual.day == expected.day, "day" + assert not actual.tag_map.bpm or actual.bpm == expected.bpm, "bpm" + assert not actual.tag_map.key or actual.key == expected.key, "key" + assert not actual.tag_map.disc_number or actual.disc_number == expected.disc_number, "disc_number" + assert not actual.tag_map.disc_total or actual.disc_total == expected.disc_total, "disc_total" + assert not actual.tag_map.compilation or actual.compilation == expected.compilation, "compilation" def test_clear_tags_dry_run(self, track: LocalTrack): track_update = track @@ -467,7 +468,7 @@ def test_update_tags_with_replace(self, track: LocalTrack): assert result.saved track_update_replace = deepcopy(track_update) - self.assert_track_tags_equal(track_update_replace, track_update, check_tag_exists=True) + self.assert_track_tags_equal_on_existing(track_update_replace, track_update) assert track_update_replace.comments == [new_uri] if new_uri == track._reader.unavailable_uri_dummy: diff --git a/tests/libraries/remote/core/utils.py b/tests/libraries/remote/core/utils.py index f3411d84..616c577c 100644 --- a/tests/libraries/remote/core/utils.py +++ b/tests/libraries/remote/core/utils.py @@ -65,33 +65,55 @@ def get_requests( """Get a get request from the history from the given URL and params""" requests = [] for request in self.request_history: - match_url = url is None - if not match_url: - if isinstance(url, str): - match_url = url.strip("/").endswith(request.path.strip("/")) - elif isinstance(url, re.Pattern): - match_url = bool(url.search(request.url)) - - match_method = method is None - if not match_method: - # noinspection PyProtectedMember - match_method = request._request.method.upper() == method.upper() - - match_params = params is None - if not match_params and request.query: - for k, v in parse_qs(request.query).items(): - if k in params and str(params[k]) != v[0]: - break - match_params = True - - match_response = response is None - if not match_response and request.body: - for k, v in request.json().items(): - if k in response and str(response[k]) != str(v): - break - match_response = True - - if match_url and match_method and match_params and match_response: + matches = [ + self._get_match_from_url(request=request, url=url), + self._get_match_from_method(request=request, method=method), + self._get_match_from_params(request=request, params=params), + self._get_match_from_response(request=request, response=response), + ] + if all(matches): requests.append(request) return requests + + @staticmethod + def _get_match_from_url(request: _RequestObjectProxy, url: str | re.Pattern[str] | None = None) -> bool: + match = url is None + if not match: + if isinstance(url, str): + match = url.strip("/").endswith(request.path.strip("/")) + elif isinstance(url, re.Pattern): + match = bool(url.search(request.url)) + + return match + + @staticmethod + def _get_match_from_method(request: _RequestObjectProxy, method: str | None = None) -> bool: + match = method is None + if not match: + # noinspection PyProtectedMember + match = request._request.method.upper() == method.upper() + + return match + + @staticmethod + def _get_match_from_params(request: _RequestObjectProxy, params: dict[str, Any] | None = None) -> bool: + match = params is None + if not match and request.query: + for k, v in parse_qs(request.query).items(): + if k in params and str(params[k]) != v[0]: + break + match = True + + return match + + @staticmethod + def _get_match_from_response(request: _RequestObjectProxy, response: dict[str, Any] | None = None) -> bool: + match = response is None + if not match and request.body: + for k, v in request.json().items(): + if k in response and str(response[k]) != str(v): + break + match = True + + return match diff --git a/tests/libraries/remote/spotify/api/test_item.py b/tests/libraries/remote/spotify/api/test_item.py index 6b68b098..41b85d9f 100644 --- a/tests/libraries/remote/spotify/api/test_item.py +++ b/tests/libraries/remote/spotify/api/test_item.py @@ -631,34 +631,42 @@ def assert_extend_tracks_results( assert len(results) == len(test) for result in results: - if features: - expected = features[result[self.id_key]] - if features_in_results: - assert result["audio_features"] == expected - else: - assert "audio_features" not in result - - if test: - assert test[result[self.id_key]]["audio_features"] == expected + self._assert_extend_tracks_result( + result=result, + key="audio_features", + test=test, + extension=features[result[self.id_key]] if features else None, + extension_in_results=features_in_results + ) + + self._assert_extend_tracks_result( + result=result, + key="audio_analysis", + test=test, + extension=analysis[result[self.id_key]] | {self.id_key: result[self.id_key]} if analysis else None, + extension_in_results=analysis_in_results + ) + + def _assert_extend_tracks_result( + self, + result: dict[str, Any], + key: str, + test: dict[str, dict[str, Any]] | None = None, + extension: dict[str, Any] | None = None, + extension_in_results: bool = True, + ): + if extension: + if extension_in_results: + assert result[key] == extension else: - assert "audio_features" not in result - if test: - assert "audio_features" not in test[result[self.id_key]] + assert key not in result - if analysis: - expected = analysis[result[self.id_key]] | {self.id_key: result[self.id_key]} - assert result["audio_analysis"] == expected - if analysis_in_results: - assert result["audio_analysis"] == expected - else: - assert "audio_analysis" not in result - - if test: - assert test[result[self.id_key]]["audio_analysis"] == expected - else: - assert "audio_analysis" not in result - if test: - assert "audio_analysis" not in test[result[self.id_key]] + if test: + assert test[result[self.id_key]][key] == extension + else: + assert key not in result + if test: + assert key not in test[result[self.id_key]] def assert_extend_tracks_calls( self,