From 1ea89be7a163ae1f1422b8116543322c09caa55b Mon Sep 17 00:00:00 2001 From: George <41969151+geo-martino@users.noreply.github.com> Date: Fri, 24 May 2024 17:22:39 -0400 Subject: [PATCH] implement async requests framework (#77) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * implement async sqlite cache backend * implement async responses & session * implement async changes to entire api package * add repository contains and clear test * implement async changes to spotify/remote api items methods * implement log method for RequestHandler and make SpotifyAPI use this * implement async changes to all spotify/remote api methods * implement async changes to all spotify/remote object methods * implement async changes to RemoteItemChecker * implement async changes to all remote processors + slow tests run first * all basic implementation done + tests passing again 🎉 * reduce complexity in extend_items method * update release history + docs * repositories now only create table on await/async with + perpetuate changes to higher order classes * remove PaginatedRequestSettings + RequestSettings now responsible for getting key * add explicit context management tests * update how-to scripts for changes * update release history * update readme how-to --- .flake8 | 4 +- README.md | 221 +++++++---- README.template.md | 221 +++++++---- .../scripts/local.library.backup-restore.py | 31 -- .../local.library.backup-restore/p0.py | 2 + .../local.library.backup-restore/p1.py | 7 + .../local.library.backup-restore/p2.py | 6 + .../local.library.backup-restore/p3.py | 15 + .../local.library.backup-restore/p4.py | 7 + docs/_howto/scripts/local.library.load.py | 46 --- .../scripts/local.library.load/p0_local.py | 6 + .../scripts/local.library.load/p0_musicbee.py | 3 + docs/_howto/scripts/local.library.load/p1.py | 15 + docs/_howto/scripts/local.library.load/p2.py | 10 + docs/_howto/scripts/local.library.load/p3.py | 14 + 
.../scripts/local.playlist.load-save.py | 50 --- .../scripts/local.playlist.load-save/p1.py | 11 + .../scripts/local.playlist.load-save/p2.py | 9 + .../scripts/local.playlist.load-save/p3.py | 5 + .../local.playlist.load-save/p3_mapper.py | 5 + .../local.playlist.load-save/p3_tracks.py | 13 + .../local.playlist.load-save/p3_wrangler.py | 5 + .../scripts/local.playlist.load-save/p4.py | 8 + .../scripts/local.playlist.load-save/p5.py | 4 + docs/_howto/scripts/local.track.load-save.py | 57 --- .../scripts/local.track.load-save/p1.py | 9 + .../scripts/local.track.load-save/p1_load.py | 3 + .../local.track.load-save/p1_wrangler.py | 4 + .../scripts/local.track.load-save/p2.py | 20 + .../scripts/local.track.load-save/p3_all.py | 6 + .../scripts/local.track.load-save/p3_tags.py | 19 + docs/_howto/scripts/remote.new-music.py | 52 --- docs/_howto/scripts/remote.new-music/p0.py | 2 + docs/_howto/scripts/remote.new-music/p2.py | 10 + docs/_howto/scripts/remote.new-music/p3.py | 16 + docs/_howto/scripts/remote.new-music/p4.py | 25 ++ docs/_howto/scripts/remote.new-music/p5.py | 25 ++ docs/_howto/scripts/remote.new-music/p99.py | 13 + docs/_howto/scripts/reports.py | 26 -- docs/_howto/scripts/reports/p0.py | 8 + .../_howto/scripts/reports/p1_missing_tags.py | 16 + .../reports/p1_playlist_differences.py | 5 + docs/_howto/scripts/spotify.api.py | 6 - .../scripts/spotify.library.backup-restore.py | 17 - .../spotify.library.backup-restore/p0.py | 5 + .../spotify.library.backup-restore/p1.py | 7 + .../spotify.library.backup-restore/p2.py | 19 + docs/_howto/scripts/spotify.load.py | 62 --- docs/_howto/scripts/spotify.load/p0.py | 2 + docs/_howto/scripts/spotify.load/p1.py | 44 +++ docs/_howto/scripts/spotify.load/p2.py | 53 +++ docs/_howto/scripts/spotify.load/p3.py | 22 ++ docs/_howto/scripts/spotify.load/p99.py | 6 + docs/_howto/scripts/sync.py | 60 --- docs/_howto/scripts/sync/p0.py | 3 + docs/_howto/scripts/sync/p1.py | 22 ++ docs/_howto/scripts/sync/p2.py | 26 ++ 
docs/_howto/scripts/sync/p3.py | 18 + docs/_howto/scripts/sync/p4.py | 13 + docs/_howto/scripts/sync/p99.py | 12 + docs/howto.library.backup-restore.rst | 30 +- docs/howto.local.library.load.rst | 18 +- docs/howto.local.playlist.load-save.rst | 36 +- docs/howto.local.track.load-save.rst | 27 +- docs/howto.remote.new-music.rst | 33 +- docs/howto.reports.rst | 10 +- docs/howto.spotify.load.rst | 26 +- docs/howto.sync.rst | 39 +- docs/release-history.rst | 41 +- musify/api/authorise.py | 94 +++-- musify/api/cache/backend/base.py | 233 ++++++----- musify/api/cache/backend/sqlite.py | 291 +++++++++----- musify/api/cache/response.py | 40 ++ musify/api/cache/session.py | 179 +++------ musify/api/exception.py | 8 +- musify/api/request.py | 251 +++++++----- musify/core/printer.py | 2 +- musify/libraries/core/collection.py | 10 +- musify/libraries/local/library/library.py | 16 +- musify/libraries/local/library/musicbee.py | 6 +- musify/libraries/local/playlist/m3u.py | 13 +- musify/libraries/local/playlist/xautopf.py | 3 +- musify/libraries/local/track/track.py | 4 +- musify/libraries/remote/core/api.py | 68 ++-- musify/libraries/remote/core/base.py | 17 +- musify/libraries/remote/core/library.py | 112 +++--- musify/libraries/remote/core/object.py | 23 +- .../libraries/remote/core/processors/check.py | 107 ++--- .../remote/core/processors/search.py | 81 ++-- musify/libraries/remote/core/types.py | 5 +- musify/libraries/remote/spotify/api/api.py | 37 +- musify/libraries/remote/spotify/api/base.py | 15 +- musify/libraries/remote/spotify/api/cache.py | 51 ++- musify/libraries/remote/spotify/api/item.py | 155 ++++---- musify/libraries/remote/spotify/api/misc.py | 19 +- .../libraries/remote/spotify/api/playlist.py | 44 +-- musify/libraries/remote/spotify/base.py | 2 - musify/libraries/remote/spotify/library.py | 18 +- musify/libraries/remote/spotify/object.py | 127 +++--- musify/libraries/remote/spotify/processors.py | 7 +- musify/log/logger.py | 7 +- musify/processors/match.py | 
8 +- musify/types.py | 12 +- pyproject.toml | 11 +- .../__resources/library/musicbee_library.xml | 4 +- .../__resources/playlist/Simple Playlist.m3u | 4 +- .../playlist/The Best Playlist Ever.xautopf | 4 +- .../{noise_flac.flac => NOISE_FLaC.flac} | Bin .../track/{noise_mp3.mp3 => noiSE_mP3.mp3} | Bin tests/api/cache/backend/test_sqlite.py | 260 ++++++++----- tests/api/cache/backend/testers.py | 366 +++++++++--------- tests/api/cache/backend/utils.py | 52 +-- tests/api/cache/test_response.py | 37 ++ tests/api/cache/test_session.py | 104 +++-- tests/api/test_authorise.py | 104 ++--- tests/api/test_request.py | 231 ++++++----- tests/conftest.py | 47 ++- tests/libraries/local/conftest.py | 1 + tests/libraries/local/playlist/test_m3u.py | 11 +- tests/libraries/local/utils.py | 4 +- tests/libraries/remote/core/api.py | 17 +- tests/libraries/remote/core/library.py | 108 +++--- tests/libraries/remote/core/object.py | 66 ++-- .../libraries/remote/core/processors/check.py | 105 ++--- .../remote/core/processors/search.py | 73 ++-- tests/libraries/remote/core/utils.py | 130 +++++-- tests/libraries/remote/spotify/api/mock.py | 152 ++++---- .../libraries/remote/spotify/api/test_api.py | 131 ++++--- .../remote/spotify/api/test_artist.py | 36 +- .../remote/spotify/api/test_cache.py | 4 +- .../libraries/remote/spotify/api/test_item.py | 257 ++++++------ .../libraries/remote/spotify/api/test_misc.py | 62 ++- .../remote/spotify/api/test_playlist.py | 145 +++---- tests/libraries/remote/spotify/api/utils.py | 5 +- tests/libraries/remote/spotify/conftest.py | 2 +- .../remote/spotify/object/test_album.py | 56 +-- .../remote/spotify/object/test_artist.py | 55 ++- .../remote/spotify/object/test_playlist.py | 119 +++--- .../remote/spotify/object/test_track.py | 12 +- .../remote/spotify/object/testers.py | 33 +- .../libraries/remote/spotify/test_library.py | 68 ++-- 141 files changed, 3775 insertions(+), 2857 deletions(-) delete mode 100644 
docs/_howto/scripts/local.library.backup-restore.py create mode 100644 docs/_howto/scripts/local.library.backup-restore/p0.py create mode 100644 docs/_howto/scripts/local.library.backup-restore/p1.py create mode 100644 docs/_howto/scripts/local.library.backup-restore/p2.py create mode 100644 docs/_howto/scripts/local.library.backup-restore/p3.py create mode 100644 docs/_howto/scripts/local.library.backup-restore/p4.py delete mode 100644 docs/_howto/scripts/local.library.load.py create mode 100644 docs/_howto/scripts/local.library.load/p0_local.py create mode 100644 docs/_howto/scripts/local.library.load/p0_musicbee.py create mode 100644 docs/_howto/scripts/local.library.load/p1.py create mode 100644 docs/_howto/scripts/local.library.load/p2.py create mode 100644 docs/_howto/scripts/local.library.load/p3.py delete mode 100644 docs/_howto/scripts/local.playlist.load-save.py create mode 100644 docs/_howto/scripts/local.playlist.load-save/p1.py create mode 100644 docs/_howto/scripts/local.playlist.load-save/p2.py create mode 100644 docs/_howto/scripts/local.playlist.load-save/p3.py create mode 100644 docs/_howto/scripts/local.playlist.load-save/p3_mapper.py create mode 100644 docs/_howto/scripts/local.playlist.load-save/p3_tracks.py create mode 100644 docs/_howto/scripts/local.playlist.load-save/p3_wrangler.py create mode 100644 docs/_howto/scripts/local.playlist.load-save/p4.py create mode 100644 docs/_howto/scripts/local.playlist.load-save/p5.py delete mode 100644 docs/_howto/scripts/local.track.load-save.py create mode 100644 docs/_howto/scripts/local.track.load-save/p1.py create mode 100644 docs/_howto/scripts/local.track.load-save/p1_load.py create mode 100644 docs/_howto/scripts/local.track.load-save/p1_wrangler.py create mode 100644 docs/_howto/scripts/local.track.load-save/p2.py create mode 100644 docs/_howto/scripts/local.track.load-save/p3_all.py create mode 100644 docs/_howto/scripts/local.track.load-save/p3_tags.py delete mode 100644 
docs/_howto/scripts/remote.new-music.py create mode 100644 docs/_howto/scripts/remote.new-music/p0.py create mode 100644 docs/_howto/scripts/remote.new-music/p2.py create mode 100644 docs/_howto/scripts/remote.new-music/p3.py create mode 100644 docs/_howto/scripts/remote.new-music/p4.py create mode 100644 docs/_howto/scripts/remote.new-music/p5.py create mode 100644 docs/_howto/scripts/remote.new-music/p99.py delete mode 100644 docs/_howto/scripts/reports.py create mode 100644 docs/_howto/scripts/reports/p0.py create mode 100644 docs/_howto/scripts/reports/p1_missing_tags.py create mode 100644 docs/_howto/scripts/reports/p1_playlist_differences.py delete mode 100644 docs/_howto/scripts/spotify.library.backup-restore.py create mode 100644 docs/_howto/scripts/spotify.library.backup-restore/p0.py create mode 100644 docs/_howto/scripts/spotify.library.backup-restore/p1.py create mode 100644 docs/_howto/scripts/spotify.library.backup-restore/p2.py delete mode 100644 docs/_howto/scripts/spotify.load.py create mode 100644 docs/_howto/scripts/spotify.load/p0.py create mode 100644 docs/_howto/scripts/spotify.load/p1.py create mode 100644 docs/_howto/scripts/spotify.load/p2.py create mode 100644 docs/_howto/scripts/spotify.load/p3.py create mode 100644 docs/_howto/scripts/spotify.load/p99.py delete mode 100644 docs/_howto/scripts/sync.py create mode 100644 docs/_howto/scripts/sync/p0.py create mode 100644 docs/_howto/scripts/sync/p1.py create mode 100644 docs/_howto/scripts/sync/p2.py create mode 100644 docs/_howto/scripts/sync/p3.py create mode 100644 docs/_howto/scripts/sync/p4.py create mode 100644 docs/_howto/scripts/sync/p99.py create mode 100644 musify/api/cache/response.py rename tests/__resources/track/{noise_flac.flac => NOISE_FLaC.flac} (100%) rename tests/__resources/track/{noise_mp3.mp3 => noiSE_mP3.mp3} (100%) create mode 100644 tests/api/cache/test_response.py diff --git a/.flake8 b/.flake8 index e9d59bf2..bc12799c 100644 --- a/.flake8 +++ b/.flake8 @@ -3,7 
+3,7 @@ extend-exclude = .venv* jupyter/ notebooks/ extend-ignore = E266 per-file-ignores = **/__init__.py:F401 - docs/_howto/scripts/*:E402 + docs/_howto/scripts/*:E402,F401,F403,F405 tests/**/test_*.py:F811 max-line-length = 120 -max-complexity = 12 +max-complexity = 10 diff --git a/README.md b/README.md index aed0bd16..0c818bc9 100644 --- a/README.md +++ b/README.md @@ -81,94 +81,155 @@ For more detailed guides, check out the [documentation](https://geo-martino.gith > The scopes listed in this example will allow access to read your library data and write to your playlists. > See Spotify Web API documentation for more information about [scopes](https://developer.spotify.com/documentation/web-api/concepts/scopes) ```python - from musify.libraries.remote.spotify.api import SpotifyAPI - - api = SpotifyAPI( - client_id="", - client_secret="", - scopes=[ - "user-library-read", - "user-follow-read", - "playlist-read-collaborative", - "playlist-read-private", - "playlist-modify-public", - "playlist-modify-private" - ], - # providing a `token_file_path` will save the generated token to your system - # for quicker authorisations in future - token_file_path="" - ) - - # authorise the program to access your Spotify data in your web browser - api.authorise() + from musify.libraries.remote.spotify.api import SpotifyAPI + + spotify_api = SpotifyAPI( + client_id="", + client_secret="", + scopes=[ + "user-library-read", + "user-follow-read", + "playlist-read-collaborative", + "playlist-read-private", + "playlist-modify-public", + "playlist-modify-private" + ], + # providing a `token_file_path` will save the generated token to your system + # for quicker authorisations in future + token_file_path="" + ) ``` -4. Create a `SpotifyLibrary` object and load your library data as follows: +4. 
Define helper functions for loading your `SpotifyLibrary` data: ```python - from musify.libraries.remote.spotify.library import SpotifyLibrary - - library = SpotifyLibrary(api=api) - - # if you have a very large library, this will take some time... - library.load() - - # ...or you may also just load distinct sections of your library - library.load_playlists() - library.load_tracks() - library.load_saved_albums() - library.load_saved_artists() - - # enrich the loaded objects; see each function's docstring for more info on arguments - # each of these will take some time depending on the size of your library - library.enrich_tracks(features=True, analysis=False, albums=False, artists=False) - library.enrich_saved_albums() - library.enrich_saved_artists(tracks=True, types=("album", "single")) - - # optionally log stats about these sections - library.log_playlists() - library.log_tracks() - library.log_albums() - library.log_artists() - - # pretty print an overview of your library - print(library) + from musify.libraries.remote.spotify.library import SpotifyLibrary + + + async def load_library(library: SpotifyLibrary) -> None: + """Load the objects for a given ``library``. Does not enrich the loaded data.""" + # authorise the program to access your Spotify data in your web browser + async with library: + # if you have a very large library, this will take some time... + await library.load() + + + async def load_library_by_parts(library: SpotifyLibrary) -> None: + """Load the objects for a given ``library`` by each of its distinct parts. 
Does not enrich the loaded data.""" + # authorise the program to access your Spotify data in your web browser + async with library: + # load distinct sections of your library + await library.load_playlists() + await library.load_tracks() + await library.load_saved_albums() + await library.load_saved_artists() + + + async def enrich_library(library: SpotifyLibrary) -> None: + """Enrich the loaded objects in the given ``library``""" + # authorise the program to access your Spotify data in your web browser + async with library: + # enrich the loaded objects; see each function's docstring for more info on arguments + # each of these will take some time depending on the size of your library + await library.enrich_tracks(features=True, analysis=False, albums=False, artists=False) + await library.enrich_saved_albums() + await library.enrich_saved_artists(tracks=True, types=("album", "single")) + + + def log_library(library: SpotifyLibrary) -> None: + """Log stats about the loaded ``library``""" + library.log_playlists() + library.log_tracks() + library.log_albums() + library.log_artists() + + # pretty print an overview of your library + print(library) ``` -5. Load some Spotify objects using any of the supported identifiers as follows: +5. 
Define helper functions for loading some Spotify objects using any of the supported identifiers: ```python - from musify.libraries.remote.spotify.object import SpotifyTrack, SpotifyAlbum, SpotifyPlaylist, SpotifyArtist - - # load by ID - track1 = SpotifyTrack.load("6fWoFduMpBem73DMLCOh1Z", api=api) - # load by URI - track2 = SpotifyTrack.load("spotify:track:4npv0xZO9fVLBmDS2XP9Bw", api=api) - # load by open/external style URL - track3 = SpotifyTrack.load("https://open.spotify.com/track/1TjVbzJUAuOvas1bL00TiH", api=api) - # load by API style URI - track4 = SpotifyTrack.load("https://api.spotify.com/v1/tracks/6pmSweeisgfxxsiLINILdJ", api=api) - - # load many different kinds of supported Spotify types - playlist = SpotifyPlaylist.load("spotify:playlist:37i9dQZF1E4zg1xOOORiP1", api=api, extend_tracks=True) - album = SpotifyAlbum.load("https://open.spotify.com/album/0rAWaAAMfzHzCbYESj4mfx", api=api, extend_tracks=True) - artist = SpotifyArtist.load("1odSzdzUpm3ZEEb74GdyiS", api=api, extend_tracks=True) - - # pretty print information about the loaded objects - print(track1, track2, track3, playlist, album, artist, sep="\n") + from musify.libraries.remote.spotify.object import SpotifyTrack, SpotifyAlbum, SpotifyPlaylist, SpotifyArtist + + + async def load_playlist(api: SpotifyAPI) -> SpotifyPlaylist: + # authorise the program to access your Spotify data in your web browser + async with api as a: + playlist = await SpotifyPlaylist.load("spotify:playlist:37i9dQZF1E4zg1xOOORiP1", api=a, extend_tracks=True) + return playlist + + + async def load_tracks(api: SpotifyAPI) -> list[SpotifyTrack]: + tracks = [] + + # authorise the program to access your Spotify data in your web browser + async with api as a: + # load by ID + tracks.append(await SpotifyTrack.load("6fWoFduMpBem73DMLCOh1Z", api=a)) + # load by URI + tracks.append(await SpotifyTrack.load("spotify:track:4npv0xZO9fVLBmDS2XP9Bw", api=a)) + # load by open/external style URL + tracks.append(await 
SpotifyTrack.load("https://open.spotify.com/track/1TjVbzJUAuOvas1bL00TiH", api=a)) + # load by API style URI + tracks.append(await SpotifyTrack.load("https://api.spotify.com/v1/tracks/6pmSweeisgfxxsiLINILdJ", api=api)) + + return tracks + + + async def load_album(api: SpotifyAPI) -> SpotifyAlbum: + # authorise the program to access your Spotify data in your web browser + async with api as a: + album = await SpotifyAlbum.load( + "https://open.spotify.com/album/0rAWaAAMfzHzCbYESj4mfx", api=a, extend_tracks=True + ) + return album + + + async def load_artist(api: SpotifyAPI) -> SpotifyArtist: + # authorise the program to access your Spotify data in your web browser + async with api as a: + artist = await SpotifyArtist.load("1odSzdzUpm3ZEEb74GdyiS", api=a, extend_tracks=True) + return artist + + + async def load_objects(api: SpotifyAPI) -> None: + playlist = await load_playlist(api) + tracks = await load_tracks(api) + album = await load_album(api) + artist = await load_artist(api) + + # pretty print information about the loaded objects + print(playlist, *tracks, album, artist, sep="\n") ``` -6. Add some tracks to a playlist in your library, synchronise with Spotify, and log the results as follows: +6. Define helper function for adding some tracks to a playlist in your library, synchronising with Spotify, and logging the results: > **NOTE**: This step will only work if you chose to load either your playlists or your entire library in step 4. 
```python - my_playlist = library.playlists[""] # case sensitive - - # add a track to the playlist - my_playlist.append(track1) - - # add an album to the playlist using either of the following - my_playlist.extend(album) - my_playlist += album - - # sync the object with Spotify and log the results - result = my_playlist.sync(dry_run=False) - library.log_sync(result) + async def update_playlist(name: str, library: SpotifyLibrary) -> None: + """Update a playlist with the given ``name`` in the given ``library``""" + tracks = await load_tracks(library.api) + album = await load_album(library.api) + await load_library(library) + + my_playlist = library.playlists[name] + + # add a track to the playlist + my_playlist.append(tracks[0]) + + # add an album to the playlist using either of the following + my_playlist.extend(album) + my_playlist += album + + # sync the object with Spotify and log the results + async with library: + result = await my_playlist.sync(dry_run=False) + library.log_sync(result) + + asyncio.run(update_playlist(spotify_api)) + ``` +7. Run the program: + ```python + import asyncio + + asyncio.run(load_objects(api)) + asyncio.run(update_playlist("", api)) # case sensitive ``` ### Local diff --git a/README.template.md b/README.template.md index 1a5fe148..d3a41966 100644 --- a/README.template.md +++ b/README.template.md @@ -81,94 +81,155 @@ For more detailed guides, check out the [documentation](https://{program_owner_u > The scopes listed in this example will allow access to read your library data and write to your playlists. 
> See Spotify Web API documentation for more information about [scopes](https://developer.spotify.com/documentation/web-api/concepts/scopes) ```python - from musify.libraries.remote.spotify.api import SpotifyAPI - - api = SpotifyAPI( - client_id="", - client_secret="", - scopes=[ - "user-library-read", - "user-follow-read", - "playlist-read-collaborative", - "playlist-read-private", - "playlist-modify-public", - "playlist-modify-private" - ], - # providing a `token_file_path` will save the generated token to your system - # for quicker authorisations in future - token_file_path="" - ) - - # authorise the program to access your Spotify data in your web browser - api.authorise() + from musify.libraries.remote.spotify.api import SpotifyAPI + + spotify_api = SpotifyAPI( + client_id="", + client_secret="", + scopes=[ + "user-library-read", + "user-follow-read", + "playlist-read-collaborative", + "playlist-read-private", + "playlist-modify-public", + "playlist-modify-private" + ], + # providing a `token_file_path` will save the generated token to your system + # for quicker authorisations in future + token_file_path="" + ) ``` -4. Create a `SpotifyLibrary` object and load your library data as follows: +4. Define helper functions for loading your `SpotifyLibrary` data: ```python - from musify.libraries.remote.spotify.library import SpotifyLibrary - - library = SpotifyLibrary(api=api) - - # if you have a very large library, this will take some time... 
- library.load() - - # ...or you may also just load distinct sections of your library - library.load_playlists() - library.load_tracks() - library.load_saved_albums() - library.load_saved_artists() - - # enrich the loaded objects; see each function's docstring for more info on arguments - # each of these will take some time depending on the size of your library - library.enrich_tracks(features=True, analysis=False, albums=False, artists=False) - library.enrich_saved_albums() - library.enrich_saved_artists(tracks=True, types=("album", "single")) - - # optionally log stats about these sections - library.log_playlists() - library.log_tracks() - library.log_albums() - library.log_artists() - - # pretty print an overview of your library - print(library) + from musify.libraries.remote.spotify.library import SpotifyLibrary + + + async def load_library(library: SpotifyLibrary) -> None: + """Load the objects for a given ``library``. Does not enrich the loaded data.""" + # authorise the program to access your Spotify data in your web browser + async with library: + # if you have a very large library, this will take some time... + await library.load() + + + async def load_library_by_parts(library: SpotifyLibrary) -> None: + """Load the objects for a given ``library`` by each of its distinct parts. 
Does not enrich the loaded data.""" + # authorise the program to access your Spotify data in your web browser + async with library: + # load distinct sections of your library + await library.load_playlists() + await library.load_tracks() + await library.load_saved_albums() + await library.load_saved_artists() + + + async def enrich_library(library: SpotifyLibrary) -> None: + """Enrich the loaded objects in the given ``library``""" + # authorise the program to access your Spotify data in your web browser + async with library: + # enrich the loaded objects; see each function's docstring for more info on arguments + # each of these will take some time depending on the size of your library + await library.enrich_tracks(features=True, analysis=False, albums=False, artists=False) + await library.enrich_saved_albums() + await library.enrich_saved_artists(tracks=True, types=("album", "single")) + + + def log_library(library: SpotifyLibrary) -> None: + """Log stats about the loaded ``library``""" + library.log_playlists() + library.log_tracks() + library.log_albums() + library.log_artists() + + # pretty print an overview of your library + print(library) ``` -5. Load some Spotify objects using any of the supported identifiers as follows: +5. 
Define helper functions for loading some Spotify objects using any of the supported identifiers: ```python - from musify.libraries.remote.spotify.object import SpotifyTrack, SpotifyAlbum, SpotifyPlaylist, SpotifyArtist - - # load by ID - track1 = SpotifyTrack.load("6fWoFduMpBem73DMLCOh1Z", api=api) - # load by URI - track2 = SpotifyTrack.load("spotify:track:4npv0xZO9fVLBmDS2XP9Bw", api=api) - # load by open/external style URL - track3 = SpotifyTrack.load("https://open.spotify.com/track/1TjVbzJUAuOvas1bL00TiH", api=api) - # load by API style URI - track4 = SpotifyTrack.load("https://api.spotify.com/v1/tracks/6pmSweeisgfxxsiLINILdJ", api=api) - - # load many different kinds of supported Spotify types - playlist = SpotifyPlaylist.load("spotify:playlist:37i9dQZF1E4zg1xOOORiP1", api=api, extend_tracks=True) - album = SpotifyAlbum.load("https://open.spotify.com/album/0rAWaAAMfzHzCbYESj4mfx", api=api, extend_tracks=True) - artist = SpotifyArtist.load("1odSzdzUpm3ZEEb74GdyiS", api=api, extend_tracks=True) - - # pretty print information about the loaded objects - print(track1, track2, track3, playlist, album, artist, sep="\n") + from musify.libraries.remote.spotify.object import SpotifyTrack, SpotifyAlbum, SpotifyPlaylist, SpotifyArtist + + + async def load_playlist(api: SpotifyAPI) -> SpotifyPlaylist: + # authorise the program to access your Spotify data in your web browser + async with api as a: + playlist = await SpotifyPlaylist.load("spotify:playlist:37i9dQZF1E4zg1xOOORiP1", api=a, extend_tracks=True) + return playlist + + + async def load_tracks(api: SpotifyAPI) -> list[SpotifyTrack]: + tracks = [] + + # authorise the program to access your Spotify data in your web browser + async with api as a: + # load by ID + tracks.append(await SpotifyTrack.load("6fWoFduMpBem73DMLCOh1Z", api=a)) + # load by URI + tracks.append(await SpotifyTrack.load("spotify:track:4npv0xZO9fVLBmDS2XP9Bw", api=a)) + # load by open/external style URL + tracks.append(await 
SpotifyTrack.load("https://open.spotify.com/track/1TjVbzJUAuOvas1bL00TiH", api=a)) + # load by API style URI + tracks.append(await SpotifyTrack.load("https://api.spotify.com/v1/tracks/6pmSweeisgfxxsiLINILdJ", api=api)) + + return tracks + + + async def load_album(api: SpotifyAPI) -> SpotifyAlbum: + # authorise the program to access your Spotify data in your web browser + async with api as a: + album = await SpotifyAlbum.load( + "https://open.spotify.com/album/0rAWaAAMfzHzCbYESj4mfx", api=a, extend_tracks=True + ) + return album + + + async def load_artist(api: SpotifyAPI) -> SpotifyArtist: + # authorise the program to access your Spotify data in your web browser + async with api as a: + artist = await SpotifyArtist.load("1odSzdzUpm3ZEEb74GdyiS", api=a, extend_tracks=True) + return artist + + + async def load_objects(api: SpotifyAPI) -> None: + playlist = await load_playlist(api) + tracks = await load_tracks(api) + album = await load_album(api) + artist = await load_artist(api) + + # pretty print information about the loaded objects + print(playlist, *tracks, album, artist, sep="\n") ``` -6. Add some tracks to a playlist in your library, synchronise with Spotify, and log the results as follows: +6. Define helper function for adding some tracks to a playlist in your library, synchronising with Spotify, and logging the results: > **NOTE**: This step will only work if you chose to load either your playlists or your entire library in step 4. 
```python - my_playlist = library.playlists[""] # case sensitive - - # add a track to the playlist - my_playlist.append(track1) - - # add an album to the playlist using either of the following - my_playlist.extend(album) - my_playlist += album - - # sync the object with Spotify and log the results - result = my_playlist.sync(dry_run=False) - library.log_sync(result) + async def update_playlist(name: str, library: SpotifyLibrary) -> None: + """Update a playlist with the given ``name`` in the given ``library``""" + tracks = await load_tracks(library.api) + album = await load_album(library.api) + await load_library(library) + + my_playlist = library.playlists[name] + + # add a track to the playlist + my_playlist.append(tracks[0]) + + # add an album to the playlist using either of the following + my_playlist.extend(album) + my_playlist += album + + # sync the object with Spotify and log the results + async with library: + result = await my_playlist.sync(dry_run=False) + library.log_sync(result) + + asyncio.run(update_playlist(spotify_api)) + ``` +7. 
Run the program: + ```python + import asyncio + + asyncio.run(load_objects(api)) + asyncio.run(update_playlist("", api)) # case sensitive ``` ### Local diff --git a/docs/_howto/scripts/local.library.backup-restore.py b/docs/_howto/scripts/local.library.backup-restore.py deleted file mode 100644 index 30480f80..00000000 --- a/docs/_howto/scripts/local.library.backup-restore.py +++ /dev/null @@ -1,31 +0,0 @@ -from musify.libraries.local.library import LocalLibrary -library = LocalLibrary() - -import json - -path = "local_backup.json" -with open(path, "w") as file: - json.dump(library.json(), file, indent=2) - -with open(path, "r") as file: - backup = json.load(file) -tracks = {track["path"]: track for track in backup["tracks"]} - -library.restore_tracks(tracks) - -from musify.libraries.local.track.field import LocalTrackField - -tags = [ - LocalTrackField.TITLE, - LocalTrackField.GENRES, - LocalTrackField.KEY, - LocalTrackField.BPM, - LocalTrackField.DATE, - LocalTrackField.COMPILATION, - LocalTrackField.IMAGES -] - -library.restore_tracks(tracks, tags=tags) - -results = library.save_tracks(tags=tags, replace=True, dry_run=False) -library.log_save_tracks_result(results) diff --git a/docs/_howto/scripts/local.library.backup-restore/p0.py b/docs/_howto/scripts/local.library.backup-restore/p0.py new file mode 100644 index 00000000..8f9c401e --- /dev/null +++ b/docs/_howto/scripts/local.library.backup-restore/p0.py @@ -0,0 +1,2 @@ +from musify.libraries.local.library import LocalLibrary +library = LocalLibrary() diff --git a/docs/_howto/scripts/local.library.backup-restore/p1.py b/docs/_howto/scripts/local.library.backup-restore/p1.py new file mode 100644 index 00000000..b2024baf --- /dev/null +++ b/docs/_howto/scripts/local.library.backup-restore/p1.py @@ -0,0 +1,7 @@ +from p0 import * + +import json + +path = "local_backup.json" +with open(path, "w") as file: + json.dump(library.json(), file, indent=2) diff --git a/docs/_howto/scripts/local.library.backup-restore/p2.py 
b/docs/_howto/scripts/local.library.backup-restore/p2.py new file mode 100644 index 00000000..57808f30 --- /dev/null +++ b/docs/_howto/scripts/local.library.backup-restore/p2.py @@ -0,0 +1,6 @@ +from p1 import * + +with open(path, "r") as file: + backup = json.load(file) + +library.restore_tracks(backup["tracks"]) diff --git a/docs/_howto/scripts/local.library.backup-restore/p3.py b/docs/_howto/scripts/local.library.backup-restore/p3.py new file mode 100644 index 00000000..035fedb5 --- /dev/null +++ b/docs/_howto/scripts/local.library.backup-restore/p3.py @@ -0,0 +1,15 @@ +from p2 import * + +from musify.libraries.local.track.field import LocalTrackField + +tags = [ + LocalTrackField.TITLE, + LocalTrackField.GENRES, + LocalTrackField.KEY, + LocalTrackField.BPM, + LocalTrackField.DATE, + LocalTrackField.COMPILATION, + LocalTrackField.IMAGES +] + +library.restore_tracks(backup, tags=tags) diff --git a/docs/_howto/scripts/local.library.backup-restore/p4.py b/docs/_howto/scripts/local.library.backup-restore/p4.py new file mode 100644 index 00000000..60d733b8 --- /dev/null +++ b/docs/_howto/scripts/local.library.backup-restore/p4.py @@ -0,0 +1,7 @@ +from p3 import * + +results = library.save_tracks(replace=True, dry_run=False) +# ... or if tags were specified +results = library.save_tracks(tags=tags, replace=True, dry_run=False) + +library.log_save_tracks_result(results) diff --git a/docs/_howto/scripts/local.library.load.py b/docs/_howto/scripts/local.library.load.py deleted file mode 100644 index 92829565..00000000 --- a/docs/_howto/scripts/local.library.load.py +++ /dev/null @@ -1,46 +0,0 @@ -from musify.libraries.local.library import LocalLibrary - -library = LocalLibrary( - library_folders=["", ...], - playlist_folder="", -) - -from musify.libraries.local.library import MusicBee - -library = MusicBee(musicbee_folder="") - -# if you have a very large library, this will take some time... 
-library.load() - -# ...or you may also just load distinct sections of your library -library.load_tracks() -library.load_playlists() - -# optionally log stats about these sections -library.log_tracks() -library.log_playlists() - -# pretty print an overview of your library -print(library) - -playlist = library.playlists[""] # case sensitive -album = next(album for album in library.albums if album.name == "") -artist = next(artist for artist in library.artists if artist.name == "") -folder = next(folder for folder in library.folders if folder.name == "") -genre = next(genre for genre in library.genres if genre.name == "") - -# pretty print information about the loaded objects -print(playlist, album, artist, folder, genre, sep="\n") - -# get a track via its title -track = library[""] # if multiple tracks have the same title, the first matching one if returned - -# get a track via its path -track = library[""] # must be an absolute path - -# get a track according to a specific tag -track = next(track for track in library if track.artist == "") -track = next(track for track in library if "" in (track.genres or [])) - -# pretty print information about this track -print(track) diff --git a/docs/_howto/scripts/local.library.load/p0_local.py b/docs/_howto/scripts/local.library.load/p0_local.py new file mode 100644 index 00000000..05d01423 --- /dev/null +++ b/docs/_howto/scripts/local.library.load/p0_local.py @@ -0,0 +1,6 @@ +from musify.libraries.local.library import LocalLibrary + +library = LocalLibrary( + library_folders=["", ...], + playlist_folder="", +) diff --git a/docs/_howto/scripts/local.library.load/p0_musicbee.py b/docs/_howto/scripts/local.library.load/p0_musicbee.py new file mode 100644 index 00000000..901959d8 --- /dev/null +++ b/docs/_howto/scripts/local.library.load/p0_musicbee.py @@ -0,0 +1,3 @@ +from musify.libraries.local.library import MusicBee + +library = MusicBee(musicbee_folder="") diff --git a/docs/_howto/scripts/local.library.load/p1.py 
b/docs/_howto/scripts/local.library.load/p1.py new file mode 100644 index 00000000..e164ba17 --- /dev/null +++ b/docs/_howto/scripts/local.library.load/p1.py @@ -0,0 +1,15 @@ +from p0_local import * + +# if you have a very large library, this will take some time... +library.load() + +# ...or you may also just load distinct sections of your library +library.load_tracks() +library.load_playlists() + +# optionally log stats about these sections +library.log_tracks() +library.log_playlists() + +# pretty print an overview of your library +print(library) diff --git a/docs/_howto/scripts/local.library.load/p2.py b/docs/_howto/scripts/local.library.load/p2.py new file mode 100644 index 00000000..f9554693 --- /dev/null +++ b/docs/_howto/scripts/local.library.load/p2.py @@ -0,0 +1,10 @@ +from p1 import * + +playlist = library.playlists[""] # case sensitive +album = next(album for album in library.albums if album.name == "") +artist = next(artist for artist in library.artists if artist.name == "") +folder = next(folder for folder in library.folders if folder.name == "") +genre = next(genre for genre in library.genres if genre.name == "") + +# pretty print information about the loaded objects +print(playlist, album, artist, folder, genre, sep="\n") diff --git a/docs/_howto/scripts/local.library.load/p3.py b/docs/_howto/scripts/local.library.load/p3.py new file mode 100644 index 00000000..6bcd728f --- /dev/null +++ b/docs/_howto/scripts/local.library.load/p3.py @@ -0,0 +1,14 @@ +from p2 import * + +# get a track via its title +track = library[""] # if multiple tracks have the same title, the first matching one is returned + +# get a track via its path +track = library[""] # must be an absolute path + +# get a track according to a specific tag +track = next(track for track in library if track.artist == "") +track = next(track for track in library if "" in (track.genres or [])) + +# pretty print information about this track +print(track) diff --git
a/docs/_howto/scripts/local.playlist.load-save.py b/docs/_howto/scripts/local.playlist.load-save.py deleted file mode 100644 index 35d33f41..00000000 --- a/docs/_howto/scripts/local.playlist.load-save.py +++ /dev/null @@ -1,50 +0,0 @@ -from musify.libraries.local.playlist import M3U, XAutoPF -from musify.libraries.local.track import load_track - -tracks = [ - load_track(""), - load_track(""), - load_track(""), - load_track(""), -] - -playlist = M3U("", tracks=tracks) - -playlist = M3U("") -playlist = XAutoPF("") - -# pretty print information about this playlist -print(playlist) - -from musify.libraries.local.playlist import load_playlist - -playlist = load_playlist("") - -from musify.libraries.local.track import load_track - -tracks = [ - load_track(""), - load_track(""), - load_track(""), - load_track(""), -] - -playlist = M3U("", tracks=tracks) - -from musify.file.path_mapper import PathMapper - -playlist = M3U("", path_mapper=PathMapper()) - -from musify.libraries.remote.spotify.processors import SpotifyDataWrangler - -playlist = M3U("", remote_wrangler=SpotifyDataWrangler()) - -# add a track to the playlist -playlist.append(load_track("")) - -# add album's and artist's tracks to the playlist using either of the following -playlist.extend(tracks) -playlist += tracks - -result = playlist.save(dry_run=False) -print(result) diff --git a/docs/_howto/scripts/local.playlist.load-save/p1.py b/docs/_howto/scripts/local.playlist.load-save/p1.py new file mode 100644 index 00000000..d2a82315 --- /dev/null +++ b/docs/_howto/scripts/local.playlist.load-save/p1.py @@ -0,0 +1,11 @@ +from musify.libraries.local.playlist import M3U +from musify.libraries.local.track import load_track + +tracks = [ + load_track(""), + load_track(""), + load_track(""), + load_track(""), +] + +playlist = M3U("", tracks=tracks) diff --git a/docs/_howto/scripts/local.playlist.load-save/p2.py b/docs/_howto/scripts/local.playlist.load-save/p2.py new file mode 100644 index 00000000..6c1b1469 --- 
/dev/null +++ b/docs/_howto/scripts/local.playlist.load-save/p2.py @@ -0,0 +1,9 @@ +from p1 import * + +from musify.libraries.local.playlist import XAutoPF + +playlist = M3U("") +playlist = XAutoPF("") + +# pretty print information about this playlist +print(playlist) diff --git a/docs/_howto/scripts/local.playlist.load-save/p3.py b/docs/_howto/scripts/local.playlist.load-save/p3.py new file mode 100644 index 00000000..a3417bf0 --- /dev/null +++ b/docs/_howto/scripts/local.playlist.load-save/p3.py @@ -0,0 +1,5 @@ +from p2 import * + +from musify.libraries.local.playlist import load_playlist + +playlist = load_playlist("") diff --git a/docs/_howto/scripts/local.playlist.load-save/p3_mapper.py b/docs/_howto/scripts/local.playlist.load-save/p3_mapper.py new file mode 100644 index 00000000..2599f8de --- /dev/null +++ b/docs/_howto/scripts/local.playlist.load-save/p3_mapper.py @@ -0,0 +1,5 @@ +from p3 import * + +from musify.file.path_mapper import PathMapper + +playlist = M3U("", path_mapper=PathMapper()) diff --git a/docs/_howto/scripts/local.playlist.load-save/p3_tracks.py b/docs/_howto/scripts/local.playlist.load-save/p3_tracks.py new file mode 100644 index 00000000..00c2e126 --- /dev/null +++ b/docs/_howto/scripts/local.playlist.load-save/p3_tracks.py @@ -0,0 +1,13 @@ +from p3 import * + +from musify.libraries.local.track import load_track + +tracks = [ + load_track(""), + load_track(""), + load_track(""), + load_track(""), +] + +name = "" # case sensitive +playlist = M3U(name, tracks=tracks) diff --git a/docs/_howto/scripts/local.playlist.load-save/p3_wrangler.py b/docs/_howto/scripts/local.playlist.load-save/p3_wrangler.py new file mode 100644 index 00000000..c19975d3 --- /dev/null +++ b/docs/_howto/scripts/local.playlist.load-save/p3_wrangler.py @@ -0,0 +1,5 @@ +from p3 import * + +from musify.libraries.remote.spotify.processors import SpotifyDataWrangler + +playlist = M3U("", remote_wrangler=SpotifyDataWrangler()) diff --git 
a/docs/_howto/scripts/local.playlist.load-save/p4.py b/docs/_howto/scripts/local.playlist.load-save/p4.py new file mode 100644 index 00000000..21b7eacc --- /dev/null +++ b/docs/_howto/scripts/local.playlist.load-save/p4.py @@ -0,0 +1,8 @@ +from p3 import * + +# add a track to the playlist +playlist.append(load_track("")) + +# add album's and artist's tracks to the playlist using either of the following +playlist.extend(tracks) +playlist += tracks diff --git a/docs/_howto/scripts/local.playlist.load-save/p5.py b/docs/_howto/scripts/local.playlist.load-save/p5.py new file mode 100644 index 00000000..1b1bc471 --- /dev/null +++ b/docs/_howto/scripts/local.playlist.load-save/p5.py @@ -0,0 +1,4 @@ +from p4 import * + +result = playlist.save(dry_run=False) +print(result) diff --git a/docs/_howto/scripts/local.track.load-save.py b/docs/_howto/scripts/local.track.load-save.py deleted file mode 100644 index f506f517..00000000 --- a/docs/_howto/scripts/local.track.load-save.py +++ /dev/null @@ -1,57 +0,0 @@ -from musify.libraries.local.track import FLAC, MP3, M4A, WMA - -track = FLAC("") -track = MP3("") -track = M4A("") -track = WMA("") - -# pretty print information about this track -print(track) - -from musify.libraries.local.track import load_track - -track = load_track("") - -from musify.libraries.remote.spotify.processors import SpotifyDataWrangler - -track = MP3("", remote_wrangler=SpotifyDataWrangler()) - -from datetime import date - -track.title = "new title" -track.artist = "new artist" -track.album = "new album" -track.track_number = 200 -track.genres = ["super cool genre", "awesome genre"] -track.key = "C#" -track.bpm = 120.5 -track.date = date(year=2024, month=1, day=1) -track.compilation = True -track.image_links.update({ - "cover front": "https://i.scdn.co/image/ab67616d0000b2737f0918f1560fc4b40b967dd4", - "cover back": "" -}) - -# see the updated information -print(track) - -# save all the tags like so... 
-results = track.save(replace=True, dry_run=False) - -# ...or select which tags you wish to save like so -from musify.libraries.local.track.field import LocalTrackField - -tags = [ - LocalTrackField.TITLE, - LocalTrackField.GENRES, - LocalTrackField.KEY, - LocalTrackField.BPM, - LocalTrackField.DATE, - LocalTrackField.COMPILATION, - LocalTrackField.IMAGES -] - -results = track.save(tags=tags, replace=True, dry_run=False) - -# print a list of the tags that were saved -print([tag.name for tag in results.updated]) diff --git a/docs/_howto/scripts/local.track.load-save/p1.py b/docs/_howto/scripts/local.track.load-save/p1.py new file mode 100644 index 00000000..8d1f90e4 --- /dev/null +++ b/docs/_howto/scripts/local.track.load-save/p1.py @@ -0,0 +1,9 @@ +from musify.libraries.local.track import FLAC, MP3, M4A, WMA + +track = FLAC("") +track = MP3("") +track = M4A("") +track = WMA("") + +# pretty print information about this track +print(track) diff --git a/docs/_howto/scripts/local.track.load-save/p1_load.py b/docs/_howto/scripts/local.track.load-save/p1_load.py new file mode 100644 index 00000000..b895abe6 --- /dev/null +++ b/docs/_howto/scripts/local.track.load-save/p1_load.py @@ -0,0 +1,3 @@ +from musify.libraries.local.track import load_track + +track = load_track("") diff --git a/docs/_howto/scripts/local.track.load-save/p1_wrangler.py b/docs/_howto/scripts/local.track.load-save/p1_wrangler.py new file mode 100644 index 00000000..e102f9f3 --- /dev/null +++ b/docs/_howto/scripts/local.track.load-save/p1_wrangler.py @@ -0,0 +1,4 @@ +from musify.libraries.local.track import MP3 +from musify.libraries.remote.spotify.processors import SpotifyDataWrangler + +track = MP3("", remote_wrangler=SpotifyDataWrangler()) diff --git a/docs/_howto/scripts/local.track.load-save/p2.py b/docs/_howto/scripts/local.track.load-save/p2.py new file mode 100644 index 00000000..7a3533b1 --- /dev/null +++ b/docs/_howto/scripts/local.track.load-save/p2.py @@ -0,0 +1,20 @@ +from p1_load import * 
+ +from datetime import date + +track.title = "new title" +track.artist = "new artist" +track.album = "new album" +track.track_number = 200 +track.genres = ["super cool genre", "awesome genre"] +track.key = "C#" +track.bpm = 120.5 +track.date = date(year=2024, month=1, day=1) +track.compilation = True +track.image_links.update({ + "cover front": "https://i.scdn.co/image/ab67616d0000b2737f0918f1560fc4b40b967dd4", + "cover back": "" +}) + +# see the updated information +print(track) diff --git a/docs/_howto/scripts/local.track.load-save/p3_all.py b/docs/_howto/scripts/local.track.load-save/p3_all.py new file mode 100644 index 00000000..4068ab6b --- /dev/null +++ b/docs/_howto/scripts/local.track.load-save/p3_all.py @@ -0,0 +1,6 @@ +from p2 import * + +results = track.save(replace=True, dry_run=False) + +# print a list of the tags that were saved +print([tag.name for tag in results.updated]) diff --git a/docs/_howto/scripts/local.track.load-save/p3_tags.py b/docs/_howto/scripts/local.track.load-save/p3_tags.py new file mode 100644 index 00000000..371d765d --- /dev/null +++ b/docs/_howto/scripts/local.track.load-save/p3_tags.py @@ -0,0 +1,19 @@ +from p2 import * + +# ...or select which tags you wish to save like so +from musify.libraries.local.track.field import LocalTrackField + +tags = [ + LocalTrackField.TITLE, + LocalTrackField.GENRES, + LocalTrackField.KEY, + LocalTrackField.BPM, + LocalTrackField.DATE, + LocalTrackField.COMPILATION, + LocalTrackField.IMAGES +] + +results = track.save(tags=tags, replace=True, dry_run=False) + +# print a list of the tags that were saved +print([tag.name for tag in results.updated]) diff --git a/docs/_howto/scripts/remote.new-music.py b/docs/_howto/scripts/remote.new-music.py deleted file mode 100644 index d1e15cc8..00000000 --- a/docs/_howto/scripts/remote.new-music.py +++ /dev/null @@ -1,52 +0,0 @@ -from musify.libraries.remote.spotify.api import SpotifyAPI -from musify.libraries.remote.spotify.library import SpotifyLibrary -api = 
SpotifyAPI() -library = SpotifyLibrary(api=api) - -library.load_saved_artists() -library.enrich_saved_artists() - -from datetime import datetime, date - -start_date = date(2024, 1, 1) -end_date = datetime.now().date() - - -def match_date(alb) -> bool: - """Match start and end dates to the release date of the given ``alb``""" - if alb.date: - return start_date <= alb.date <= end_date - if alb.month: - return start_date.year <= alb.year <= end_date.year and start_date.month <= alb.month <= end_date.month - if alb.year: - return start_date.year <= alb.year <= end_date.year - return False - - -from musify.libraries.remote.core.enum import RemoteObjectType - -albums = [album for artist in library.artists for album in artist.albums if match_date(album)] -albums_need_extend = [album for album in albums if len(album.tracks) < album.track_total] -if albums_need_extend: - kind = RemoteObjectType.ALBUM - key = api.collection_item_map[kind] - - bar = library.logger.get_iterator(iterable=albums_need_extend, desc="Getting album tracks", unit="albums") - for album in bar: - api.extend_items(album.response, kind=kind, key=key) - album.refresh(skip_checks=False) - -# log stats about the loaded artists -library.log_artists() - -from musify.libraries.remote.spotify.object import SpotifyPlaylist - -name = "New Music Playlist" -playlist = SpotifyPlaylist.create(api=api, name=name) - -tracks = [track for album in sorted(albums, key=lambda x: x.date, reverse=True) for track in album] -playlist.extend(tracks, allow_duplicates=False) - -# sync the object with Spotify and log the results -results = playlist.sync(kind="refresh", reload=False, dry_run=False) -library.log_sync({name: results}) diff --git a/docs/_howto/scripts/remote.new-music/p0.py b/docs/_howto/scripts/remote.new-music/p0.py new file mode 100644 index 00000000..737e8a0f --- /dev/null +++ b/docs/_howto/scripts/remote.new-music/p0.py @@ -0,0 +1,2 @@ +from musify.libraries.remote.spotify.api import SpotifyAPI +api = SpotifyAPI() 
diff --git a/docs/_howto/scripts/remote.new-music/p2.py b/docs/_howto/scripts/remote.new-music/p2.py new file mode 100644 index 00000000..23ce42bc --- /dev/null +++ b/docs/_howto/scripts/remote.new-music/p2.py @@ -0,0 +1,10 @@ +from p0 import * + +from musify.libraries.remote.core.library import RemoteLibrary + + +async def load_artists(library: RemoteLibrary) -> None: + """Loads the artists followed by a given user in their given ``library`` and enriches them.""" + async with library: + await library.load_saved_artists() + await library.enrich_saved_artists(types=("album", "single")) diff --git a/docs/_howto/scripts/remote.new-music/p3.py b/docs/_howto/scripts/remote.new-music/p3.py new file mode 100644 index 00000000..529381ad --- /dev/null +++ b/docs/_howto/scripts/remote.new-music/p3.py @@ -0,0 +1,16 @@ +from p2 import * + +from datetime import date + +from musify.libraries.remote.core.object import RemoteAlbum + + +def match_date(alb: RemoteAlbum, start: date, end: date) -> bool: + """Match ``start`` and ``end`` dates to the release date of the given ``alb``""" + if alb.date: + return start <= alb.date <= end + if alb.month: + return start.year <= alb.year <= end.year and start.month <= alb.month <= end.month + if alb.year: + return start.year <= alb.year <= end.year + return False diff --git a/docs/_howto/scripts/remote.new-music/p4.py b/docs/_howto/scripts/remote.new-music/p4.py new file mode 100644 index 00000000..a0db8c28 --- /dev/null +++ b/docs/_howto/scripts/remote.new-music/p4.py @@ -0,0 +1,25 @@ +from p3 import * + +from musify.libraries.remote.core.enum import RemoteObjectType +from musify.libraries.remote.spotify.object import SpotifyAlbum + + +async def get_albums(library: RemoteLibrary, start: date, end: date) -> list[SpotifyAlbum]: + """ + Get the albums that match the ``start`` and ``end`` date range from a given ``library`` + and get the tracks on those albums if needed. 
+ """ + albums = [album for artist in library.artists for album in artist.albums if match_date(album, start, end)] + albums_need_extend = [album for album in albums if len(album.tracks) < album.track_total] + + if albums_need_extend: + kind = RemoteObjectType.ALBUM + key = api.collection_item_map[kind] + + bar = library.logger.get_iterator(iterable=albums_need_extend, desc="Getting album tracks", unit="albums") + async with library: + for album in bar: + await api.extend_items(album.response, kind=kind, key=key) + album.refresh(skip_checks=False) + + return albums diff --git a/docs/_howto/scripts/remote.new-music/p5.py b/docs/_howto/scripts/remote.new-music/p5.py new file mode 100644 index 00000000..2d9601ae --- /dev/null +++ b/docs/_howto/scripts/remote.new-music/p5.py @@ -0,0 +1,25 @@ +from p4 import * + +from musify.libraries.remote.spotify.object import SpotifyPlaylist + + +async def create_new_music_playlist(name: str, library: RemoteLibrary, start: date, end: date) -> None: + """ + Create a playlist with the given ``name`` in the given ``library`` featuring + new music by followed artists released between ``start`` date and ``end`` date. 
+ """ + await load_artists(library) + albums = await get_albums(library, start, end) + + # log stats about the loaded artists + library.log_artists() + + async with library: + playlist = await SpotifyPlaylist.create(api=api, name=name) + + tracks = [track for album in sorted(albums, key=lambda x: x.date, reverse=True) for track in album] + playlist.extend(tracks, allow_duplicates=False) + + # sync the object with Spotify and log the results + results = await playlist.sync(kind="refresh", reload=False, dry_run=False) + library.log_sync({name: results}) diff --git a/docs/_howto/scripts/remote.new-music/p99.py b/docs/_howto/scripts/remote.new-music/p99.py new file mode 100644 index 00000000..da5c715b --- /dev/null +++ b/docs/_howto/scripts/remote.new-music/p99.py @@ -0,0 +1,13 @@ +from p5 import * + +import asyncio +from datetime import datetime, timedelta + +from musify.libraries.remote.spotify.library import SpotifyLibrary + +playlist_name = "New Music Playlist" +library = SpotifyLibrary(api=api) +end = datetime.now().date() +start = end - timedelta(weeks=4) + +asyncio.run(create_new_music_playlist(playlist_name, library, start, end)) diff --git a/docs/_howto/scripts/reports.py b/docs/_howto/scripts/reports.py deleted file mode 100644 index a1416a48..00000000 --- a/docs/_howto/scripts/reports.py +++ /dev/null @@ -1,26 +0,0 @@ -from musify.libraries.local.library import LocalLibrary -local_library = LocalLibrary() - -from musify.libraries.remote.spotify.api import SpotifyAPI -from musify.libraries.remote.spotify.library import SpotifyLibrary -api = SpotifyAPI() -remote_library = SpotifyLibrary(api=api) - -from musify.report import report_playlist_differences - -report_playlist_differences(source=local_library, reference=remote_library) - -from musify.libraries.local.track.field import LocalTrackField -from musify.report import report_missing_tags - -tags = [ - LocalTrackField.TITLE, - LocalTrackField.GENRES, - LocalTrackField.KEY, - LocalTrackField.BPM, - 
LocalTrackField.DATE, - LocalTrackField.COMPILATION, - LocalTrackField.IMAGES -] - -report_missing_tags(collections=local_library, tags=tags, match_all=False) diff --git a/docs/_howto/scripts/reports/p0.py b/docs/_howto/scripts/reports/p0.py new file mode 100644 index 00000000..b2bc429f --- /dev/null +++ b/docs/_howto/scripts/reports/p0.py @@ -0,0 +1,8 @@ +from musify.libraries.local.library import LocalLibrary +from musify.libraries.remote.spotify.api import SpotifyAPI +from musify.libraries.remote.spotify.library import SpotifyLibrary + +local_library = LocalLibrary() + +api = SpotifyAPI() +remote_library = SpotifyLibrary(api=api) diff --git a/docs/_howto/scripts/reports/p1_missing_tags.py b/docs/_howto/scripts/reports/p1_missing_tags.py new file mode 100644 index 00000000..dc7b87a3 --- /dev/null +++ b/docs/_howto/scripts/reports/p1_missing_tags.py @@ -0,0 +1,16 @@ +from p0 import * + +from musify.libraries.local.track.field import LocalTrackField +from musify.report import report_missing_tags + +tags = [ + LocalTrackField.TITLE, + LocalTrackField.GENRES, + LocalTrackField.KEY, + LocalTrackField.BPM, + LocalTrackField.DATE, + LocalTrackField.COMPILATION, + LocalTrackField.IMAGES +] + +report_missing_tags(collections=local_library, tags=tags, match_all=False) diff --git a/docs/_howto/scripts/reports/p1_playlist_differences.py b/docs/_howto/scripts/reports/p1_playlist_differences.py new file mode 100644 index 00000000..23aa6828 --- /dev/null +++ b/docs/_howto/scripts/reports/p1_playlist_differences.py @@ -0,0 +1,5 @@ +from p0 import * + +from musify.report import report_playlist_differences + +report_playlist_differences(source=local_library, reference=remote_library) diff --git a/docs/_howto/scripts/spotify.api.py b/docs/_howto/scripts/spotify.api.py index f7b07a69..54cf3ff0 100644 --- a/docs/_howto/scripts/spotify.api.py +++ b/docs/_howto/scripts/spotify.api.py @@ -15,9 +15,3 @@ # for quicker authorisations in future token_file_path="" ) - -# authorise the 
program to access your Spotify data in your web browser -api.authorise() - -from musify.libraries.remote.spotify.library import SpotifyLibrary -library = SpotifyLibrary(api=api) diff --git a/docs/_howto/scripts/spotify.library.backup-restore.py b/docs/_howto/scripts/spotify.library.backup-restore.py deleted file mode 100644 index 2e567d9b..00000000 --- a/docs/_howto/scripts/spotify.library.backup-restore.py +++ /dev/null @@ -1,17 +0,0 @@ -from musify.libraries.remote.spotify.api import SpotifyAPI -from musify.libraries.remote.spotify.library import SpotifyLibrary -api = SpotifyAPI() -library = SpotifyLibrary(api=api) - -import json - -path = "remote_backup.json" -with open(path, "w") as file: - json.dump(library.json(), file, indent=2) - -with open(path, "r") as file: - backup = json.load(file) - -library.restore_playlists(backup["playlists"]) -results = library.sync(kind="refresh", reload=False, dry_run=False) -library.log_sync(results) diff --git a/docs/_howto/scripts/spotify.library.backup-restore/p0.py b/docs/_howto/scripts/spotify.library.backup-restore/p0.py new file mode 100644 index 00000000..d6ef4e30 --- /dev/null +++ b/docs/_howto/scripts/spotify.library.backup-restore/p0.py @@ -0,0 +1,5 @@ +from musify.libraries.remote.spotify.api import SpotifyAPI +from musify.libraries.remote.spotify.library import SpotifyLibrary + +api = SpotifyAPI() +library = SpotifyLibrary(api=api) diff --git a/docs/_howto/scripts/spotify.library.backup-restore/p1.py b/docs/_howto/scripts/spotify.library.backup-restore/p1.py new file mode 100644 index 00000000..d3592725 --- /dev/null +++ b/docs/_howto/scripts/spotify.library.backup-restore/p1.py @@ -0,0 +1,7 @@ +from p0 import * + +import json + +path = "remote_backup.json" +with open(path, "w") as file: + json.dump(library.json(), file, indent=2) diff --git a/docs/_howto/scripts/spotify.library.backup-restore/p2.py b/docs/_howto/scripts/spotify.library.backup-restore/p2.py new file mode 100644 index 00000000..41fe8507 --- 
/dev/null +++ b/docs/_howto/scripts/spotify.library.backup-restore/p2.py @@ -0,0 +1,19 @@ +from p1 import * + +import asyncio + +from musify.libraries.remote.core.library import RemoteLibrary + +with open(path, "r") as file: + backup = json.load(file) + + +async def restore_remote_library(library: RemoteLibrary, backup) -> None: + """Restore the playlists in a remote ``library`` from the given ``backup``""" + async with library: + await library.restore_playlists(backup["playlists"]) + results = await library.sync(kind="refresh", reload=False, dry_run=False) + + library.log_sync(results) + +asyncio.run(restore_remote_library(library, backup)) diff --git a/docs/_howto/scripts/spotify.load.py b/docs/_howto/scripts/spotify.load.py deleted file mode 100644 index e7d60cfc..00000000 --- a/docs/_howto/scripts/spotify.load.py +++ /dev/null @@ -1,62 +0,0 @@ -from musify.libraries.remote.spotify.api import SpotifyAPI -api = SpotifyAPI() - -from musify.libraries.remote.spotify.library import SpotifyLibrary - -library = SpotifyLibrary(api=api) - -# if you have a very large library, this will take some time... 
-library.load() - -# ...or you may also just load distinct sections of your library -library.load_playlists() -library.load_tracks() -library.load_saved_albums() -library.load_saved_artists() - -# enrich the loaded objects; see each function's docstring for more info on arguments -# each of these will take some time depending on the size of your library -library.enrich_tracks(features=True, analysis=False, albums=False, artists=False) -library.enrich_saved_albums() -library.enrich_saved_artists(tracks=True, types=("album", "single")) - -# optionally log stats about these sections -library.log_playlists() -library.log_tracks() -library.log_albums() -library.log_artists() - -# pretty print an overview of your library -print(library) - -from musify.libraries.remote.spotify.object import SpotifyTrack, SpotifyAlbum, SpotifyPlaylist, SpotifyArtist - -# load by ID -track1 = SpotifyTrack.load("6fWoFduMpBem73DMLCOh1Z", api=api) -# load by URI -track2 = SpotifyTrack.load("spotify:track:4npv0xZO9fVLBmDS2XP9Bw", api=api) -# load by open/external style URL -track3 = SpotifyTrack.load("https://open.spotify.com/track/1TjVbzJUAuOvas1bL00TiH", api=api) -# load by API style URI -track4 = SpotifyTrack.load("https://api.spotify.com/v1/tracks/6pmSweeisgfxxsiLINILdJ", api=api) - -# load many different kinds of supported Spotify types -playlist = SpotifyPlaylist.load("spotify:playlist:37i9dQZF1E4zg1xOOORiP1", api=api, extend_tracks=True) -album = SpotifyAlbum.load("https://open.spotify.com/album/0rAWaAAMfzHzCbYESj4mfx", api=api, extend_tracks=True) -artist = SpotifyArtist.load("1odSzdzUpm3ZEEb74GdyiS", api=api, extend_tracks=True) - -# pretty print information about the loaded objects -print(track1, track2, track3, playlist, album, artist, sep="\n") - -my_playlist = library.playlists[""] # case sensitive - -# add a track to the playlist -my_playlist.append(track1) - -# add an album to the playlist using either of the following -my_playlist.extend(album) -my_playlist += album - -# sync 
the object with Spotify and log the results -result = my_playlist.sync(dry_run=False) -library.log_sync(result) diff --git a/docs/_howto/scripts/spotify.load/p0.py b/docs/_howto/scripts/spotify.load/p0.py new file mode 100644 index 00000000..737e8a0f --- /dev/null +++ b/docs/_howto/scripts/spotify.load/p0.py @@ -0,0 +1,2 @@ +from musify.libraries.remote.spotify.api import SpotifyAPI +api = SpotifyAPI() diff --git a/docs/_howto/scripts/spotify.load/p1.py b/docs/_howto/scripts/spotify.load/p1.py new file mode 100644 index 00000000..f63707a4 --- /dev/null +++ b/docs/_howto/scripts/spotify.load/p1.py @@ -0,0 +1,44 @@ +from p2 import * + +from musify.libraries.remote.spotify.library import SpotifyLibrary + + +async def load_library(library: SpotifyLibrary) -> None: + """Load the objects for a given ``library``. Does not enrich the loaded data.""" + # authorise the program to access your Spotify data in your web browser + async with library: + # if you have a very large library, this will take some time... + await library.load() + + +async def load_library_by_parts(library: SpotifyLibrary) -> None: + """Load the objects for a given ``library`` by each of its distinct parts. 
Does not enrich the loaded data.""" + # authorise the program to access your Spotify data in your web browser + async with library: + # load distinct sections of your library + await library.load_playlists() + await library.load_tracks() + await library.load_saved_albums() + await library.load_saved_artists() + + +async def enrich_library(library: SpotifyLibrary) -> None: + """Enrich the loaded objects in the given ``library``""" + # authorise the program to access your Spotify data in your web browser + async with library: + # enrich the loaded objects; see each function's docstring for more info on arguments + # each of these will take some time depending on the size of your library + await library.enrich_tracks(features=True, analysis=False, albums=False, artists=False) + await library.enrich_saved_albums() + await library.enrich_saved_artists(tracks=True, types=("album", "single")) + + +def log_library(library: SpotifyLibrary) -> None: + """Log stats about the loaded ``library``""" + library.log_playlists() + library.log_tracks() + library.log_albums() + library.log_artists() + + # pretty print an overview of your library + print(library) diff --git a/docs/_howto/scripts/spotify.load/p2.py b/docs/_howto/scripts/spotify.load/p2.py new file mode 100644 index 00000000..ec1c4ce4 --- /dev/null +++ b/docs/_howto/scripts/spotify.load/p2.py @@ -0,0 +1,53 @@ +from p0 import * + +from musify.libraries.remote.spotify.object import SpotifyTrack, SpotifyAlbum, SpotifyPlaylist, SpotifyArtist + + +async def load_playlist(api: SpotifyAPI) -> SpotifyPlaylist: + # authorise the program to access your Spotify data in your web browser + async with api as a: + playlist = await SpotifyPlaylist.load("spotify:playlist:37i9dQZF1E4zg1xOOORiP1", api=a, extend_tracks=True) + return playlist + + +async def load_tracks(api: SpotifyAPI) -> list[SpotifyTrack]: + tracks = [] + + # authorise the program to access your Spotify data in your web browser + async with api as a: + # load by ID + 
tracks.append(await SpotifyTrack.load("6fWoFduMpBem73DMLCOh1Z", api=a)) + # load by URI + tracks.append(await SpotifyTrack.load("spotify:track:4npv0xZO9fVLBmDS2XP9Bw", api=a)) + # load by open/external style URL + tracks.append(await SpotifyTrack.load("https://open.spotify.com/track/1TjVbzJUAuOvas1bL00TiH", api=a)) + # load by API style URI + tracks.append(await SpotifyTrack.load("https://api.spotify.com/v1/tracks/6pmSweeisgfxxsiLINILdJ", api=a)) + + return tracks + + +async def load_album(api: SpotifyAPI) -> SpotifyAlbum: + # authorise the program to access your Spotify data in your web browser + async with api as a: + album = await SpotifyAlbum.load( + "https://open.spotify.com/album/0rAWaAAMfzHzCbYESj4mfx", api=a, extend_tracks=True + ) + return album + + +async def load_artist(api: SpotifyAPI) -> SpotifyArtist: + # authorise the program to access your Spotify data in your web browser + async with api as a: + artist = await SpotifyArtist.load("1odSzdzUpm3ZEEb74GdyiS", api=a, extend_tracks=True) + return artist + + +async def load_objects(api: SpotifyAPI) -> None: + playlist = await load_playlist(api) + tracks = await load_tracks(api) + album = await load_album(api) + artist = await load_artist(api) + + # pretty print information about the loaded objects + print(playlist, *tracks, album, artist, sep="\n") diff --git a/docs/_howto/scripts/spotify.load/p3.py b/docs/_howto/scripts/spotify.load/p3.py new file mode 100644 index 00000000..d252b716 --- /dev/null +++ b/docs/_howto/scripts/spotify.load/p3.py @@ -0,0 +1,22 @@ +from p1 import * + + +async def update_playlist(name: str, library: SpotifyLibrary) -> None: + """Update a playlist with the given ``name`` in the given ``library``""" + tracks = await load_tracks(library.api) + album = await load_album(library.api) + await load_library(library) + + my_playlist = library.playlists[name] + + # add a track to the playlist + my_playlist.append(tracks[0]) + + # add an album to the playlist using either of the following
+ my_playlist.extend(album) + my_playlist += album + + # sync the object with Spotify and log the results + async with library: + result = await my_playlist.sync(dry_run=False) + library.log_sync(result) diff --git a/docs/_howto/scripts/spotify.load/p99.py b/docs/_howto/scripts/spotify.load/p99.py new file mode 100644 index 00000000..d0de1bd3 --- /dev/null +++ b/docs/_howto/scripts/spotify.load/p99.py @@ -0,0 +1,6 @@ +from p3 import * + +import asyncio + +asyncio.run(load_objects(api)) +asyncio.run(update_playlist("", SpotifyLibrary(api=api))) # case sensitive diff --git a/docs/_howto/scripts/sync.py b/docs/_howto/scripts/sync.py deleted file mode 100644 index b626659d..00000000 --- a/docs/_howto/scripts/sync.py +++ /dev/null @@ -1,60 +0,0 @@ -from musify.libraries.remote.spotify.api import SpotifyAPI -from musify.processors.match import ItemMatcher -api = SpotifyAPI() - -from musify.libraries.local.library import LocalLibrary -from musify.libraries.remote.spotify.processors import SpotifyDataWrangler - -local_library = LocalLibrary( - library_folders=["", ...], - playlist_folder="", - # this wrangler will be needed to interpret matched URIs as valid - remote_wrangler=SpotifyDataWrangler(), -) -local_library.load() - -from musify.libraries.remote.core.processors.search import RemoteItemSearcher -from musify.libraries.remote.core.processors.check import RemoteItemChecker -from musify.libraries.remote.spotify.factory import SpotifyObjectFactory - -albums = local_library.albums[:3] -factory = SpotifyObjectFactory(api=api) - -searcher = RemoteItemSearcher(matcher=ItemMatcher(), object_factory=factory) -searcher.search(albums) - -checker = RemoteItemChecker(matcher=ItemMatcher(), object_factory=factory) -checker.check(albums) - -from musify.libraries.remote.spotify.object import SpotifyTrack - -for album in albums: - for local_track in album: - remote_track = SpotifyTrack.load(local_track.uri, api=api) - - local_track.title = remote_track.title - local_track.artist = remote_track.artist - 
local_track.date = remote_track.date - local_track.genres = remote_track.genres - local_track.image_links = remote_track.image_links - - # alternatively, just merge all tags - local_track |= remote_track - - # save the track here or... - local_track.save(replace=True, dry_run=False) - - # ...save all tracks on the album at once here - album.save_tracks(replace=True, dry_run=False) - -from musify.libraries.remote.spotify.library import SpotifyLibrary - -remote_library = SpotifyLibrary(api=api) -remote_library.load_playlists() - -local_playlist = local_library.playlists[""] # case sensitive -remote_playlist = remote_library.playlists[""] # case sensitive - -# sync the object with Spotify and pretty print info about the reloaded remote playlist -remote_playlist.sync(items=local_playlist, kind="new", reload=True, dry_run=False) -print(remote_playlist) diff --git a/docs/_howto/scripts/sync/p0.py b/docs/_howto/scripts/sync/p0.py new file mode 100644 index 00000000..521e1e79 --- /dev/null +++ b/docs/_howto/scripts/sync/p0.py @@ -0,0 +1,3 @@ +from musify.libraries.remote.spotify.api import SpotifyAPI + +api = SpotifyAPI() diff --git a/docs/_howto/scripts/sync/p1.py b/docs/_howto/scripts/sync/p1.py new file mode 100644 index 00000000..95aa4ed4 --- /dev/null +++ b/docs/_howto/scripts/sync/p1.py @@ -0,0 +1,22 @@ +from p0 import * + +from collections.abc import Collection + +from musify.libraries.core.collection import MusifyCollection +from musify.libraries.remote.core.factory import RemoteObjectFactory +from musify.libraries.remote.core.processors.search import RemoteItemSearcher +from musify.libraries.remote.core.processors.check import RemoteItemChecker +from musify.processors.match import ItemMatcher + + +async def match_albums_to_remote(albums: Collection[MusifyCollection], factory: RemoteObjectFactory) -> None: + """Match the items in the given ``albums`` to the remote API's database and assign URIs to them.""" + matcher = ItemMatcher() + + searcher = 
RemoteItemSearcher(matcher=matcher, object_factory=factory) + async with searcher: + await searcher.search(albums) + + checker = RemoteItemChecker(matcher=matcher, object_factory=factory) + async with checker: + await checker.check(albums) diff --git a/docs/_howto/scripts/sync/p2.py b/docs/_howto/scripts/sync/p2.py new file mode 100644 index 00000000..227c72b4 --- /dev/null +++ b/docs/_howto/scripts/sync/p2.py @@ -0,0 +1,26 @@ +from p1 import * + +from musify.libraries.local.collection import LocalAlbum + + +async def sync_albums(albums: list[LocalAlbum], factory: RemoteObjectFactory) -> None: + """Sync the local ``albums`` with tag data from the api in the given ``factory``""" + async with factory.api: + for album in albums: + for local_track in album: + remote_track = await factory.track.load(local_track.uri, api=factory.api) + + local_track.title = remote_track.title + local_track.artist = remote_track.artist + local_track.date = remote_track.date + local_track.genres = remote_track.genres + local_track.image_links = remote_track.image_links + + # alternatively, just merge all tags + local_track |= remote_track + + # save the track here or... 
+ local_track.save(replace=True, dry_run=False) + + # ...save all tracks on the album at once here + album.save_tracks(replace=True, dry_run=False) diff --git a/docs/_howto/scripts/sync/p3.py b/docs/_howto/scripts/sync/p3.py new file mode 100644 index 00000000..480a4ac1 --- /dev/null +++ b/docs/_howto/scripts/sync/p3.py @@ -0,0 +1,18 @@ +from p2 import * + +from musify.libraries.local.library import LocalLibrary +from musify.libraries.remote.core.library import RemoteLibrary + + +async def sync_local_playlist_with_remote(name: str, local_library: LocalLibrary, remote_library: RemoteLibrary): + """Sync ``local_library`` playlist with given ``name`` to its matching ``remote_library`` playlist.""" + async with api: + await remote_library.load_playlists() + + local_playlist = local_library.playlists[name] + remote_playlist = remote_library.playlists[name] + + # sync the object with Spotify and pretty print info about the reloaded remote playlist + await remote_playlist.sync(items=local_playlist, kind="new", reload=True, dry_run=False) + + print(remote_playlist) diff --git a/docs/_howto/scripts/sync/p4.py b/docs/_howto/scripts/sync/p4.py new file mode 100644 index 00000000..3fc0209c --- /dev/null +++ b/docs/_howto/scripts/sync/p4.py @@ -0,0 +1,13 @@ +from p3 import * + +from musify.libraries.local.library import LocalLibrary + +local_library = LocalLibrary( + library_folders=["", ...], + playlist_folder="", + # this wrangler will be needed to interpret matched URIs as valid + remote_wrangler=api.wrangler, +) +local_library.load() + +albums = local_library.albums diff --git a/docs/_howto/scripts/sync/p99.py b/docs/_howto/scripts/sync/p99.py new file mode 100644 index 00000000..e97e4d76 --- /dev/null +++ b/docs/_howto/scripts/sync/p99.py @@ -0,0 +1,12 @@ +from p4 import * + +import asyncio + +from musify.libraries.remote.spotify.library import SpotifyLibrary + +remote_library = SpotifyLibrary(api=api) +playlist = "" # case sensitive + 
+asyncio.run(match_albums_to_remote(albums, remote_library.factory)) +asyncio.run(sync_albums(albums, remote_library.factory)) +asyncio.run(sync_local_playlist_with_remote(playlist, local_library, remote_library)) diff --git a/docs/howto.library.backup-restore.rst b/docs/howto.library.backup-restore.rst index 97526b81..3f9064aa 100644 --- a/docs/howto.library.backup-restore.rst +++ b/docs/howto.library.backup-restore.rst @@ -16,37 +16,35 @@ Backup and restore a local library 1. Load a local library. For more information on how to do this see :ref:`load-local` - .. literalinclude:: _howto/scripts/local.library.load.py + .. literalinclude:: _howto/scripts/local.library.load/p0_local.py :language: Python - :lines: 1-6 - .. literalinclude:: _howto/scripts/local.library.load.py + .. literalinclude:: _howto/scripts/local.library.load/p0_musicbee.py :language: Python - :lines: 12-24 2. Backup your library to JSON: - .. literalinclude:: _howto/scripts/local.library.backup-restore.py + .. literalinclude:: _howto/scripts/local.library.backup-restore/p1.py :language: Python - :lines: 4-8 + :lines: 3- 3. Restore the tags for all tracks in your library from a JSON file: - .. literalinclude:: _howto/scripts/local.library.backup-restore.py + .. literalinclude:: _howto/scripts/local.library.backup-restore/p2.py :language: Python - :lines: 10-14 + :lines: 3- ... or restore only a specific set of tags: - .. literalinclude:: _howto/scripts/local.library.backup-restore.py + .. literalinclude:: _howto/scripts/local.library.backup-restore/p3.py :language: Python - :lines: 16-28 + :lines: 3- 4. Save the tags to the track files: - .. literalinclude:: _howto/scripts/local.library.backup-restore.py + .. literalinclude:: _howto/scripts/local.library.backup-restore/p4.py :language: Python - :lines: 30-31 + :lines: 3- Backup and restore a remote library @@ -64,12 +62,12 @@ Backup and restore a remote library 2. Backup your library to JSON: - .. 
literalinclude:: _howto/scripts/spotify.library.backup-restore.py + .. literalinclude:: _howto/scripts/spotify.library.backup-restore/p1.py :language: Python - :lines: 6-10 + :lines: 3- 3. Restore the playlists in your library from a JSON file and sync the playlists: - .. literalinclude:: _howto/scripts/spotify.library.backup-restore.py + .. literalinclude:: _howto/scripts/spotify.library.backup-restore/p2.py :language: Python - :lines: 12-17 + :lines: 3- diff --git a/docs/howto.local.library.load.rst b/docs/howto.local.library.load.rst index de5f1d7d..45d58e29 100644 --- a/docs/howto.local.library.load.rst +++ b/docs/howto.local.library.load.rst @@ -17,9 +17,8 @@ You can create one of any of the supported local library types for this guide as * Generic local library - .. literalinclude:: _howto/scripts/local.library.load.py + .. literalinclude:: _howto/scripts/local.library.load/p0_local.py :language: Python - :lines: 1-6 * MusicBee @@ -27,9 +26,8 @@ You can create one of any of the supported local library types for this guide as To be able to use a MusicBee library, you will need to have installed the ``musicbee`` optional dependencies. See :ref:`installation` for more details. - .. literalinclude:: _howto/scripts/local.library.load.py + .. literalinclude:: _howto/scripts/local.library.load/p0_musicbee.py :language: Python - :lines: 8-10 Load your library and other objects @@ -37,18 +35,18 @@ Load your library and other objects 1. Load your library: - .. literalinclude:: _howto/scripts/local.library.load.py + .. literalinclude:: _howto/scripts/local.library.load/p1.py :language: Python - :lines: 12-24 + :lines: 3- 2. Get collections from your library: - .. literalinclude:: _howto/scripts/local.library.load.py + .. literalinclude:: _howto/scripts/local.library.load/p2.py :language: Python - :lines: 26-33 + :lines: 3- 3. Get a track from your library using any of the following identifiers: - .. literalinclude:: _howto/scripts/local.library.load.py + .. 
literalinclude:: _howto/scripts/local.library.load/p3.py :language: Python - :lines: 35-46 + :lines: 3- diff --git a/docs/howto.local.playlist.load-save.rst b/docs/howto.local.playlist.load-save.rst index 3e2a15d1..94194a86 100644 --- a/docs/howto.local.playlist.load-save.rst +++ b/docs/howto.local.playlist.load-save.rst @@ -13,11 +13,10 @@ In this example, you will: Create a playlist ----------------- -You can create a playlist as follows: +You can create a playlist from scratch as follows: -.. literalinclude:: _howto/scripts/local.playlist.load-save.py +.. literalinclude:: _howto/scripts/local.playlist.load-save/p1.py :language: Python - :lines: 1-11 Load a playlist --------------- @@ -28,30 +27,31 @@ You can load a playlist as follows: To be able to use the XAutoPF playlist type, you will need to have installed the ``musicbee`` optional dependencies. See :ref:`installation` for more details. -.. literalinclude:: _howto/scripts/local.playlist.load-save.py +.. literalinclude:: _howto/scripts/local.playlist.load-save/p2.py :language: Python - :lines: 13-17 + :lines: 3- You can also just have Musify automatically determine the playlist type to load based on the file's extension: -.. literalinclude:: _howto/scripts/local.playlist.load-save.py +.. literalinclude:: _howto/scripts/local.playlist.load-save/p3.py :language: Python - :lines: 19-21 + :lines: 3- If you already have some tracks loaded, and you want the playlist to only use those tracks instead of loading -the tracks itself, you can pass these preloaded tracks to the playlist too. +the tracks itself, you can pass these preloaded tracks to the playlist too. This will still load the playlist +from the given file, but it will use the given track objects instead of creating new ones. -.. literalinclude:: _howto/scripts/local.playlist.load-save.py +.. 
literalinclude:: _howto/scripts/local.playlist.load-save/p3_tracks.py :language: Python - :lines: 23-32 + :lines: 3- There may also be cases where the files in the file need mapping to be loaded e.g. if the paths contained in the playlist file are relative paths. You may give the playlist object a :py:class:`.PathMapper` or :py:class:`.PathStemMapper` to handle this. -.. literalinclude:: _howto/scripts/local.playlist.load-save.py +.. literalinclude:: _howto/scripts/local.playlist.load-save/p3_mapper.py :language: Python - :lines: 34-36 + :lines: 3- If you want to be able to read/update URIs on the loaded tracks, you'll need to provide a :py:class:`.RemoteDataWrangler` @@ -59,9 +59,9 @@ to the playlist object for the relevant music streaming source. The following is an example for doing this with Spotify as the data source: -.. literalinclude:: _howto/scripts/local.playlist.load-save.py +.. literalinclude:: _howto/scripts/local.playlist.load-save/p3_wrangler.py :language: Python - :lines: 38-40 + :lines: 3- Modify the playlist @@ -69,12 +69,12 @@ Modify the playlist 1. Add some tracks to the playlist: - .. literalinclude:: _howto/scripts/local.playlist.load-save.py + .. literalinclude:: _howto/scripts/local.playlist.load-save/p4.py :language: Python - :lines: 42-47 + :lines: 3- 2. Save the playlist: - .. literalinclude:: _howto/scripts/local.playlist.load-save.py + .. literalinclude:: _howto/scripts/local.playlist.load-save/p5.py :language: Python - :lines: 49-50 + :lines: 3- diff --git a/docs/howto.local.track.load-save.rst b/docs/howto.local.track.load-save.rst index 90dd5f10..a90d8d35 100644 --- a/docs/howto.local.track.load-save.rst +++ b/docs/howto.local.track.load-save.rst @@ -12,26 +12,23 @@ In this example, you will: Load a track ------------ -You can load a track as follows: +Load a track as follows: -.. literalinclude:: _howto/scripts/local.track.load-save.py +.. 
literalinclude:: _howto/scripts/local.track.load-save/p1.py :language: Python - :lines: 1-9 You can also just have Musify automatically determine the track type to load based on the file's extension: -.. literalinclude:: _howto/scripts/local.track.load-save.py +.. literalinclude:: _howto/scripts/local.track.load-save/p1_load.py :language: Python - :lines: 11-13 If you want to be able to assign a URI to your track, you'll need to provide a :py:class:`.RemoteDataWrangler` to the track object for the relevant music streaming source. The following is an example for doing this with Spotify as the data source: -.. literalinclude:: _howto/scripts/local.track.load-save.py +.. literalinclude:: _howto/scripts/local.track.load-save/p1_wrangler.py :language: Python - :lines: 15-17 Modify the track's tags @@ -43,12 +40,18 @@ Modify the track's tags 1. Change some tags: - .. literalinclude:: _howto/scripts/local.track.load-save.py + .. literalinclude:: _howto/scripts/local.track.load-save/p2.py :language: Python - :lines: 19-36 + :lines: 3- -2. Save the tags to the file: +2. Save all the modified tags to the file: - .. literalinclude:: _howto/scripts/local.track.load-save.py + .. literalinclude:: _howto/scripts/local.track.load-save/p3_all.py :language: Python - :lines: 38-57 + :lines: 3- + + ... or select exactly which modified tags you wish to save: + + .. literalinclude:: _howto/scripts/local.track.load-save/p3_tags.py + :language: Python + :lines: 3- diff --git a/docs/howto.remote.new-music.rst b/docs/howto.remote.new-music.rst index 50d98ef3..c2f4e936 100644 --- a/docs/howto.remote.new-music.rst +++ b/docs/howto.remote.new-music.rst @@ -13,7 +13,7 @@ In this example, you will: Create the playlist ------------------- -1. Create a remote library object: +1. Create a :py:class:`.RemoteAPI` object: .. note:: This step uses the :py:class:`.SpotifyLibrary`, but any supported music streaming service @@ -22,26 +22,33 @@ Create the playlist .. 
literalinclude:: _howto/scripts/spotify.api.py :language: Python -2. Load data about your followed artists: +2. If you haven't already, you will need to load and enrich data about your followed artists. + You may use this helper function to help do so: - .. literalinclude:: _howto/scripts/remote.new-music.py + .. literalinclude:: _howto/scripts/remote.new-music/p2.py :language: Python - :lines: 6-7 + :lines: 3- -3. Define the date range you wish to get track for and define this helper function for filtering albums: +3. Define helper function for filtering albums: - .. literalinclude:: _howto/scripts/remote.new-music.py + .. literalinclude:: _howto/scripts/remote.new-music/p3.py :language: Python - :lines: 9-23 + :lines: 3- -4. Filter the albums and load the tracks for only these albums: +4. Define helper function for filtering the albums and loading the tracks for only these albums: - .. literalinclude:: _howto/scripts/remote.new-music.py + .. literalinclude:: _howto/scripts/remote.new-music/p4.py :language: Python - :lines: 26-39 + :lines: 3- -5. Create a new playlist and add these tracks: +5. Define driver function for creating the playlist: - .. literalinclude:: _howto/scripts/remote.new-music.py + .. literalinclude:: _howto/scripts/remote.new-music/p5.py :language: Python - :lines: 41-51 + :lines: 3- + +5. Define the required parameters and run the operation: + + .. literalinclude:: _howto/scripts/remote.new-music/p99.py + :language: Python + :lines: 3- diff --git a/docs/howto.reports.rst b/docs/howto.reports.rst index 9b2d5f0f..70ae0204 100644 --- a/docs/howto.reports.rst +++ b/docs/howto.reports.rst @@ -16,18 +16,18 @@ Report on differences in playlists 2. Run the report: - .. literalinclude:: _howto/scripts/reports.py + .. literalinclude:: _howto/scripts/reports/p1_playlist_differences.py :language: Python - :lines: 9-11 + :lines: 3- Report on missing tags ---------------------- -1. Load a local library or collection of local objects. 
See other guides for more info on how to do this. +1. Load a local library or collection of local objects. See :ref:`load-local` for more info on how to do this. 2. Run the report: - .. literalinclude:: _howto/scripts/reports.py + .. literalinclude:: _howto/scripts/reports/p1_missing_tags.py :language: Python - :lines: 13-26 + :lines: 3- diff --git a/docs/howto.spotify.load.rst b/docs/howto.spotify.load.rst index c7229fa3..9158ffac 100644 --- a/docs/howto.spotify.load.rst +++ b/docs/howto.spotify.load.rst @@ -34,23 +34,27 @@ Set up the Spotify API Load your library ----------------- -1. Create a :py:class:`.SpotifyLibrary` object and load your library data as follows: +1. Define helper functions for loading your library data: - .. literalinclude:: _howto/scripts/spotify.load.py + .. literalinclude:: _howto/scripts/spotify.load/p1.py :language: Python - :lines: 4-30 + :lines: 3- -2. Load some Spotify objects using any of the supported identifiers as follows: +2. Define helper functions for loading some Spotify objects using any of the supported identifiers: - .. literalinclude:: _howto/scripts/spotify.load.py + .. literalinclude:: _howto/scripts/spotify.load/p2.py :language: Python - :lines: 32-49 + :lines: 3- -3. Add some tracks to a playlist in your library, synchronise with Spotify, and log the results as follows: +3. Define helper function for adding some tracks to a playlist in your library, synchronising with Spotify, + and logging the results: - .. note:: - This step will only work if you chose to load either your playlists or your entire library in step 4. + .. literalinclude:: _howto/scripts/spotify.load/p3.py + :language: Python + :lines: 3- + +4. Run the program: - .. literalinclude:: _howto/scripts/spotify.load.py + .. 
literalinclude:: _howto/scripts/spotify.load/p99.py :language: Python - :lines: 51-62 + :lines: 3- diff --git a/docs/howto.sync.rst b/docs/howto.sync.rst index 1aee4a51..ed8a3900 100644 --- a/docs/howto.sync.rst +++ b/docs/howto.sync.rst @@ -17,33 +17,40 @@ In this example, you will: Sync data --------- -1. Set up and load at least one local library with a remote wrangler attached, and one remote API object: +1. Define a helper function to search for tracks and check the results: - .. literalinclude:: _howto/scripts/sync.py + .. literalinclude:: _howto/scripts/sync/p1.py :language: Python - :lines: 5-14 + :lines: 3- - .. literalinclude:: _howto/scripts/spotify.api.py +2. Define a helper function to load the matched tracks, get tags from the music streaming service, + and save the tags to the file: + + .. note:: + By default, URIs are saved to the ``comments`` tag. + + .. literalinclude:: _howto/scripts/sync/p2.py :language: Python - :lines: 1-20 + :lines: 3- -2. Search for tracks and check the results: +3. Define a helper function to sync the local playlist with a remote playlist + once all tracks in a playlist have URIs assigned: - .. literalinclude:: _howto/scripts/sync.py + .. literalinclude:: _howto/scripts/sync/p3.py :language: Python - :lines: 16-27 + :lines: 3- -3. Load the matched tracks, get tags from the music streaming service, and save the tags to the file: +4. Set up and load a remote API object and local library with a wrangler attached: - .. note:: - By default, URIs are saved to the ``comments`` tag. + .. literalinclude:: _howto/scripts/spotify.api.py + :language: Python - .. literalinclude:: _howto/scripts/sync.py + .. literalinclude:: _howto/scripts/sync/p4.py :language: Python - :lines: 29-48 + :lines: 3- -4. Once all tracks in a playlist have URIs assigned, sync the local playlist with a remote playlist: +4. Set up the remote library and run the program: - .. literalinclude:: _howto/scripts/sync.py + .. 
literalinclude:: _howto/scripts/sync/p99.py :language: Python - :lines: 50-60 + :lines: 3- diff --git a/docs/release-history.rst b/docs/release-history.rst index 7ee8bf6b..f613e369 100644 --- a/docs/release-history.rst +++ b/docs/release-history.rst @@ -37,23 +37,38 @@ and this project adheres to `Semantic Versioning T: + self._sanitise_params(kwargs.get("params")) + self._sanitise_params(kwargs.get("data")) + return kwargs + + def _sanitise_params(self, params: MutableMapping[str, Any] | None) -> None: + if not params: + return + + for k, v in params.items(): + if isinstance(v, MutableMapping): + self._sanitise_params(v) + elif isinstance(v, bool) or not isinstance(v, str | int | float): + params[k] = json.dumps(v) + def load_token(self) -> dict[str, Any] | None: """Load stored token from given path""" if not self.token_file_path or not os.path.exists(self.token_file_path): @@ -188,7 +204,7 @@ def save_token(self) -> None: with open(self.token_file_path, "w") as file: json.dump(self.token, file, indent=2) - def __call__(self, force_load: bool = False, force_new: bool = False) -> dict[str, str]: + async def authorise(self, force_load: bool = False, force_new: bool = False) -> dict[str, str]: """ Main method for authorisation which tests/refreshes/reauthorises as needed. @@ -208,11 +224,11 @@ def __call__(self, force_load: bool = False, force_new: bool = False) -> dict[st if self.auth_args and self.token is None: log = "Saved access token not found" if self.token is None else "New token generation forced" self.logger.debug(f"{log}. 
Generating new token...") - self._authorise_user() - self._request_token(**self.auth_args) + await self._authorise_user() + await self._request_token(**self.auth_args) # test current token - valid = self.test_token() + valid = await self.test_token() refreshed = False # if invalid, first attempt to re-authorise via refresh_token @@ -222,9 +238,9 @@ def __call__(self, force_load: bool = False, force_new: bool = False) -> dict[st if "data" not in self.refresh_args: self.refresh_args["data"] = {} self.refresh_args["data"]["refresh_token"] = self.token["refresh_token"] - self._request_token(**self.refresh_args) + await self._request_token(**self.refresh_args) - valid = self.test_token() + valid = await self.test_token() refreshed = True if not valid and self.auth_args: # generate new token @@ -234,9 +250,9 @@ def __call__(self, force_load: bool = False, force_new: bool = False) -> dict[st log = "Access token is not valid and and no refresh data found" self.logger.debug(f"{log}. Generating new token...") - self._authorise_user() - self._request_token(**self.auth_args) - valid = self.test_token() + await self._authorise_user() + await self._request_token(**self.auth_args) + valid = await self.test_token() if not self.token: raise APIError("Token not generated") @@ -248,7 +264,7 @@ def __call__(self, force_load: bool = False, force_new: bool = False) -> dict[st return self.headers - def _authorise_user(self) -> None: + async def _authorise_user(self) -> None: """ Get user authentication code by authorising through user's browser. @@ -272,17 +288,13 @@ def _authorise_user(self) -> None: print(f"\33[1mWaiting for code, timeout in {socket_listener.timeout} seconds... 
\33[0m") # add redirect URI to auth_args and user_args - if not self.auth_args.get("data"): - self.auth_args["data"] = {} - if not self.user_args.get("params"): - self.user_args["params"] = {} redirect_uri = f"http://{self._user_auth_socket_address}:{self._user_auth_socket_port}/" - self.auth_args["data"]["redirect_uri"] = redirect_uri - self.user_args["params"]["redirect_uri"] = redirect_uri + self.auth_args.setdefault("data", {}).setdefault("redirect_uri", redirect_uri) + self.user_args.setdefault("params", {}).setdefault("redirect_uri", redirect_uri) # open authorise webpage and wait for the redirect - auth_response = requests.post(**self.user_args) - webopen(auth_response.url) + async with aiohttp.request(method="POST", **self._sanitise_kwargs(self.user_args)) as resp: + webopen(str(resp.url)) request, _ = socket_listener.accept() request.send(f"Code received! You may now close this window and return to {PROGRAM_NAME}...".encode("utf-8")) @@ -291,13 +303,10 @@ def _authorise_user(self) -> None: # format out the access code from the returned response path_raw = next(line for line in request.recv(8196).decode("utf-8").split('\n') if line.startswith("GET")) - code = parse_qs(urlparse(path_raw).query)["code"][0] - - if "data" not in self.auth_args: - self.auth_args["data"] = {} - self.auth_args["data"]["code"] = code + code = unquote(URL(path_raw).query["code"]) + self.auth_args.setdefault("data", {}).setdefault("code", code) - def _request_token(self, **requests_args) -> dict[str, Any]: + async def _request_token(self, **requests_args) -> dict[str, Any]: """ Authenticates/refreshes basic API access and returns token. @@ -305,7 +314,8 @@ def _request_token(self, **requests_args) -> dict[str, Any]: :param data: requests.post() ``data`` parameter to send as a request for authorisation. :param requests_args: Other requests.post() parameters to send as a request for authorisation. 
""" - auth_response = requests.post(**requests_args).json() + async with aiohttp.request(method="POST", **self._sanitise_kwargs(requests_args)) as resp: + auth_response = await resp.json() # add granted and expiry times to token auth_response["granted_at"] = datetime.now().timestamp() @@ -322,7 +332,7 @@ def _request_token(self, **requests_args) -> dict[str, Any]: self.logger.debug(f"New token successfully generated: {self.token_safe}") return auth_response - def test_token(self) -> bool: + async def test_token(self) -> bool: """Test validity of token and given headers. Returns True if all tests pass, False otherwise""" if not self.token: return False @@ -333,7 +343,7 @@ def test_token(self) -> bool: if not token_has_no_error: # skip other tests if error return False - valid_response = self._test_valid_response() + valid_response = await self._test_valid_response() not_expired = self._test_expiry() return token_has_no_error and valid_response and not_expired @@ -344,16 +354,18 @@ def _test_no_error(self) -> bool: self.logger.debug(f"Token contains no error test: {result}") return result - def _test_valid_response(self) -> bool: + async def _test_valid_response(self) -> bool: """Check for expected response""" if self.test_args is None or self.test_condition is None: return True - response = requests.get(headers=self.headers, **self.test_args) - try: - response = response.json() - except json.decoder.JSONDecodeError: - response = response.text + async with aiohttp.request( + method="GET", headers=self.headers, **self._sanitise_kwargs(self.test_args) + ) as response: + try: + response = await response.json() + except json.decoder.JSONDecodeError: + response = await response.text() result = self.test_condition(response) self.logger.debug(f"Valid response test: {result}") diff --git a/musify/api/cache/backend/base.py b/musify/api/cache/backend/base.py index 27a851b1..ff58987a 100644 --- a/musify/api/cache/backend/base.py +++ b/musify/api/cache/backend/base.py @@ -1,26 
+1,23 @@ import logging from abc import ABC, abstractmethod -from collections.abc import MutableMapping, Callable, Collection, Hashable +from collections.abc import MutableMapping, Callable, Collection, Hashable, AsyncIterable, Mapping, Awaitable from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, Protocol, Self +from typing import Any, Self, AsyncContextManager from dateutil.relativedelta import relativedelta -from requests import Request, PreparedRequest, Response +from aiohttp import RequestInfo, ClientRequest, ClientResponse +from yarl import URL from musify.api.exception import CacheError from musify.log.logger import MusifyLogger from musify.types import UnitCollection from musify.utils import to_collection -DEFAULT_EXPIRE: timedelta = timedelta(weeks=1) - - -class Connection(Protocol): - """The expected protocol for a backend connection""" +type CacheRequestType = RequestInfo | ClientRequest | ClientResponse +type RepositoryRequestType[K] = K | CacheRequestType - def close(self) -> None: - """Close the connection to the repository.""" +DEFAULT_EXPIRE: timedelta = timedelta(weeks=1) @dataclass @@ -29,46 +26,34 @@ class RequestSettings(ABC): #: That name of the repository in the backend name: str - @staticmethod + @property @abstractmethod - def get_name(value: Any) -> str | None: - """Extracts the name to assign to a cache entry in the repository from a given ``value``.""" + def fields(self) -> tuple[str, ...]: + """ + The names of the fields relating to the keys extracted by :py:meth:`get_key` in the order + in which they appear from the results of this method. 
+ """ raise NotImplementedError - @staticmethod @abstractmethod - def get_id(url: str) -> str | None: - """Extracts the ID for a request from the given ``url``.""" - raise NotImplementedError - - -class PaginatedRequestSettings(RequestSettings, ABC): - """ - Settings for a request type for a given endpoint which returns a paginated response - to be used to configure a repository in the cache backend. - """ - - @staticmethod - @abstractmethod - def get_offset(url: str) -> int: - """Extracts the offset for a paginated request from the given ``url``.""" + def get_key(self, *args, **kwargs) -> tuple: + """Extracts the name to assign to a cache entry in the repository from a given ``value``.""" raise NotImplementedError @staticmethod @abstractmethod - def get_limit(url: str) -> int: - """Extracts the limit for a paginated request from the given ``url``.""" + def get_name(response: dict[str, Any]) -> str | None: + """Extracts the name to assign to a cache entry in the repository from a given ``response``.""" raise NotImplementedError -class ResponseRepository[T: Connection, KT, VT](MutableMapping[KT, VT], Hashable, ABC): +class ResponseRepository[K, V](AsyncIterable[tuple[K, V]], Awaitable, AsyncContextManager, Hashable, ABC): """ Represents a repository in the backend cache, providing a dict-like interface for interacting with this repository. A repository is a data store within the backend e.g. a table in a database. - :param connection: The connection to the backend cache. :param settings: The settings to use to identify and interact with the repository in the backend. :param expire: The expiry time to apply to cached responses after which responses are invalidated. 
""" @@ -80,111 +65,145 @@ def expire(self) -> datetime: """The datetime representing the maximum allowed expiry time from now.""" return datetime.now() + self._expire - def __init__(self, connection: T, settings: RequestSettings, expire: timedelta | relativedelta = DEFAULT_EXPIRE): + @classmethod + @abstractmethod + def create(cls, *args, **kwargs) -> Self: + """ + Set up the backend repository in the backend database if it doesn't already exist + and return the initialised object that represents it. + """ + raise NotImplementedError + + def __init__(self, settings: RequestSettings, expire: timedelta | relativedelta = DEFAULT_EXPIRE): # noinspection PyTypeChecker #: The :py:class:`MusifyLogger` for this object self.logger: MusifyLogger = logging.getLogger(__name__) - self.connection = connection self.settings = settings self._expire = expire def __hash__(self): return hash(self.settings.name) - def close(self) -> None: - """Close the connection to the repository.""" - self.commit() - self.connection.close() - @abstractmethod - def commit(self) -> None: + async def commit(self) -> None: """Commit the changes to the data""" raise NotImplementedError @abstractmethod - def count(self, expired: bool = True) -> int: + async def close(self) -> None: + """Close the connection to the repository.""" + raise NotImplementedError + + @abstractmethod + async def count(self, include_expired: bool = True) -> int: """ Get the number of responses in this repository. - :param expired: Whether to include expired responses in the final count. + :param include_expired: Whether to include expired responses in the final count. :return: The number of responses in this repository. 
""" raise NotImplementedError @abstractmethod - def serialize(self, value: Any) -> VT: + async def contains(self, request: RepositoryRequestType[K]) -> bool: + """Check whether the repository contains a given ``request``""" + raise NotImplementedError + + @abstractmethod + async def clear(self, expired_only: bool = False) -> int: + """ + Clear the repository of all entries. + + :param expired_only: Whether to only remove responses that have expired. + :return: The number of responses cleared from the repository. + """ + raise NotImplementedError + + @abstractmethod + def serialize(self, value: Any) -> V: """Serialize a given ``value`` to a type that can be persisted to the repository.""" raise NotImplementedError @abstractmethod - def deserialize(self, value: VT) -> Any: + def deserialize(self, value: V) -> Any: """Deserialize a value from the repository to the expected response value type.""" raise NotImplementedError @abstractmethod - def get_key_from_request(self, request: Request | PreparedRequest | Response) -> KT: - """Extract the keys to use when persisting responses for a given ``request``""" + def get_key_from_request(self, request: RepositoryRequestType[K]) -> K: + """Extract the key to use when persisting responses for a given ``request``""" raise NotImplementedError - def get_response(self, request: KT | Request | PreparedRequest | Response) -> VT | None: + @abstractmethod + async def get_response(self, request: RepositoryRequestType[K]) -> V | None: """Get the response relating to the given ``request`` from this repository if it exists.""" - if isinstance(request, Request | PreparedRequest | Response): - request = self.get_key_from_request(request) - if not request: - return - - return self.get(request, None) + raise NotImplementedError - def get_responses(self, requests: Collection[KT | Request | PreparedRequest | Response]) -> list[VT]: + async def get_responses(self, requests: Collection[RepositoryRequestType[K]]) -> list[V]: """ Get the responses 
relating to the given ``requests`` from this repository if they exist. Returns results unordered. """ - results = [self.get_response(request) for request in requests] + results = [await self.get_response(request) for request in requests] return [result for result in results if result is not None] - def save_response(self, response: Response) -> None: + async def save_response(self, response: Collection[K, V] | ClientResponse) -> None: """Save the given ``response`` to this repository if a key can be extracted from it. Safely fail if not""" - keys = self.get_key_from_request(response) - if not keys: + if isinstance(response, Collection): + key, value = response + await self._set_item_from_key_value_pair(key, value) + return + + key = self.get_key_from_request(response) + if not key: return - self[keys] = response.text + await self._set_item_from_key_value_pair(key, await response.text()) - def save_responses(self, responses: Collection[Response]) -> None: + @abstractmethod + async def _set_item_from_key_value_pair(self, __key: K, __value: Any) -> None: + raise NotImplementedError + + async def save_responses(self, responses: Mapping[K, V] | Collection[ClientResponse]) -> None: """ - Save the given ``responses`` to this repository if keys can be extracted from them. + Save the given ``responses`` to this repository if a key can be extracted from them. Safely fail on those that can't. 
""" + if isinstance(responses, Mapping): + for key, value in responses.items(): + await self._set_item_from_key_value_pair(key, value) + return + for response in responses: - self.save_response(response) + await self.save_response(response) - def delete_response(self, request: KT | Request | PreparedRequest | Response) -> None: - """Delete the given ``request`` from this repository if it exists.""" - if isinstance(request, Request | PreparedRequest | Response): - request = self.get_key_from_request(request) - if not request: - return - self.pop(request, None) + @abstractmethod + async def delete_response(self, request: RepositoryRequestType[K]) -> bool: + """ + Delete the given ``request`` from this repository if it exists. + Returns True if deleted in the repository and False if ``request`` was not found in the repository. + """ + raise NotImplementedError - def delete_responses(self, requests: Collection[KT | Request | PreparedRequest | Response]) -> None: - """Delete the given ``requests`` from this repository if they exist.""" - for request in requests: - self.delete_response(request) + async def delete_responses(self, requests: Collection[RepositoryRequestType[K]]) -> int: + """ + Delete the given ``requests`` from this repository if they exist. + Returns the number of the given ``requests`` deleted in the repository. + """ + return sum([await self.delete_response(request) for request in requests]) -class ResponseCache[CT: Connection, ST: ResponseRepository](MutableMapping[str, ST], ABC): +class ResponseCache[ST: ResponseRepository](MutableMapping[str, ST], Awaitable, AsyncContextManager): """ Represents a backend cache of many repositories, providing a dict-like interface for interacting with them. :param cache_name: The name to give to this cache. - :param connection: The connection to the backend cache. :param repository_getter: A function that can be used to identify the repository in this cache that matches a given URL. 
def __init__(
        self,
        cache_name: str,
        repository_getter: Callable[[Self, str | URL], ST] = None,
        expire: timedelta | relativedelta = DEFAULT_EXPIRE,
):
    super().__init__()

    #: The name to give to this cache
    self.cache_name = cache_name
    #: A function that identifies the repository in this cache that matches a given URL
    self.repository_getter = repository_getter
    #: The expiry time to apply to cached responses after which responses are invalidated
    self.expire = expire

    self._repositories: dict[str, ST] = {}

async def __aenter__(self) -> Self:
    return await self

def __repr__(self):
    # NOTE: __repr__ must be side-effect free; the previous revision called
    # ``self.__aenter__()`` here without awaiting it, which only created a dangling
    # coroutine (never executed, RuntimeWarning) — that call is removed.
    return repr(self._repositories)
async def get_response(self, request: CacheRequestType) -> Any:
    """Get the response relating to the given ``request`` from the appropriate repository if it exists."""
    repository = self.get_repository_from_requests([request])
    if repository is not None:
        return await repository.get_response(request)

async def get_responses(self, requests: Collection[CacheRequestType]) -> list:
    """
    Get the responses relating to the given ``requests`` from the appropriate repository if they exist.
    Returns results unordered.
    """
    repository = self.get_repository_from_requests(requests)
    if repository is None:
        # honour the declared ``list`` return type rather than silently returning None
        return []
    return await repository.get_responses(requests)

async def save_response(self, response: ClientResponse) -> None:
    """Save the given ``response`` to the appropriate repository if a key can be extracted from it."""
    repository = self.get_repository_from_requests([response])
    if repository is not None:
        await repository.save_response(response)

async def save_responses(self, responses: Collection[ClientResponse]) -> None:
    """
    Save the given ``responses`` to the appropriate repository if a key can be extracted from them.
    Safely fail on those that can't.
    """
    repository = self.get_repository_from_requests(responses)
    if repository is not None:
        await repository.save_responses(responses)

async def delete_response(self, request: CacheRequestType) -> bool:
    """
    Delete the given ``request`` from the appropriate repository if it exists.
    Returns True if deleted in the repository and False if ``request`` was not found in the repository.
    """
    repository = self.get_repository_from_requests([request])
    if repository is None:
        # no matching repository means nothing was deleted; keep the declared bool contract
        return False
    return await repository.delete_response(request)

async def delete_responses(self, requests: Collection[CacheRequestType]) -> int:
    """
    Delete the given ``requests`` from the appropriate repository.
    Returns the number of the given ``requests`` deleted in the repository.
    """
    repository = self.get_repository_from_requests(requests)
    if repository is None:
        # no matching repository means nothing was deleted; keep the declared int contract
        return 0
    return await repository.delete_responses(requests)
required_modules_installed + +try: + import aiosqlite +except ImportError: + aiosqlite = None + +REQUIRED_MODULES = [aiosqlite] -class SQLiteTable[KT: tuple[Any, ...], VT: str](ResponseRepository[sqlite3.Connection, KT, VT]): +class SQLiteTable[K: tuple[Any, ...], V: str](ResponseRepository[K, V]): __slots__ = () @@ -31,45 +40,8 @@ class SQLiteTable[KT: tuple[Any, ...], VT: str](ResponseRepository[sqlite3.Conne #: The column under which the response expiry time is stored in the table expiry_column = "expires_at" - @property - def _primary_key_columns(self) -> Mapping[str, str]: - """A map of column names to column data types for the primary keys of this repository.""" - keys = {"method": "VARCHAR(10)", "id": "VARCHAR(50)"} - if isinstance(self.settings, PaginatedRequestSettings): - keys["offset"] = "INT2" - keys["size"] = "INT2" - - return keys - - def get_key_from_request(self, request: Request | PreparedRequest | Response) -> KT | None: - if isinstance(request, Response): - request = request.request - - id_ = self.settings.get_id(request.url) - if not id_: - return - - keys = [str(request.method), id_] - if isinstance(self.settings, PaginatedRequestSettings): - keys.append(self.settings.get_offset(request.url)) - keys.append(self.settings.get_limit(request.url)) - - return tuple(keys) - - def __init__( - self, - connection: sqlite3.Connection, - settings: RequestSettings, - expire: timedelta | relativedelta = DEFAULT_EXPIRE - ): - super().__init__(connection=connection, settings=settings, expire=expire) - - self.create_table() - - def create_table(self): - """Create the table for this repository type in the backend database if it doesn't already exist.""" + async def create(self) -> Self: ddl_sep = "\t, " - ddl = "\n".join(( f"CREATE TABLE IF NOT EXISTS {self.settings.name} (", "\t" + f"\n{ddl_sep}".join( @@ -80,45 +52,126 @@ def create_table(self): f"{ddl_sep}{self.expiry_column} TIMESTAMP NOT NULL", f"{ddl_sep}{self.data_column} TEXT", f"{ddl_sep}PRIMARY KEY 
({", ".join(self._primary_key_columns)})", - ");\n" - f"CREATE INDEX IF NOT EXISTS idx_{self.expiry_column} ON {self.settings.name}({self.expiry_column});" + ");", + f"CREATE INDEX IF NOT EXISTS idx_{self.expiry_column} " + f"ON {self.settings.name}({self.expiry_column});" )) self.logger.debug(f"Creating {self.settings.name!r} table with the following DDL:\n{ddl}") - self.connection.executescript(ddl) + await self.connection.executescript(ddl) + await self.commit() + + return self + + def __init__( + self, + connection: aiosqlite.Connection, + settings: RequestSettings, + expire: timedelta | relativedelta = DEFAULT_EXPIRE, + ): + required_modules_installed(REQUIRED_MODULES, self) + + super().__init__(settings=settings, expire=expire) + + self.connection = connection + + def __await__(self): + return self.create().__await__() + + async def __aenter__(self) -> Self: + if not self.connection.is_alive(): + await self.connection + return await self + + async def __aexit__(self, __exc_type, __exc_value, __traceback) -> None: + if self.connection.is_alive(): + await self.connection.__aexit__(__exc_type, __exc_value, __traceback) + + async def commit(self) -> None: + await self.connection.commit() + + async def close(self) -> None: + await self.commit() + await self.connection.close() + + @property + def _primary_key_columns(self) -> Mapping[str, str]: + """A map of column names to column data types for the primary keys of this repository.""" + expected_columns = self.settings.fields - def commit(self) -> None: - self.connection.commit() + keys = {"method": "VARCHAR(10)"} + if "id" in expected_columns: + keys["id"] = "VARCHAR(50)" + if "offset" in expected_columns: + keys["offset"] = "INT2" + if "size" in expected_columns: + keys["size"] = "INT2" + + return keys + + def get_key_from_request(self, request: RepositoryRequestType[K]) -> K | None: + if isinstance(request, ClientRequest | ClientResponse): + request = request.request_info + if not isinstance(request, 
async def contains(self, request: RepositoryRequestType[K]) -> bool:
    """Check whether the repository contains a given ``request``"""
    key = self.get_key_from_request(request)
    if not key:
        # no key can be extracted, so the request can never have been cached
        return False

    and_sep = "\n\tAND "
    query = "\n".join((
        f"SELECT COUNT(*) FROM {self.settings.name}",
        f"WHERE {self.expiry_column} > ?",
        "\tAND " + and_sep.join(f"{column} = ?" for column in self._primary_key_columns),
    ))
    async with self.connection.execute(query, (datetime.now().isoformat(), *key)) as cur:
        row = await cur.fetchone()
    return row[0] > 0

async def clear(self, expired_only: bool = False) -> int:
    """
    Clear the repository of all entries.

    :param expired_only: Whether to only remove responses that have expired.
    :return: The number of responses cleared from the repository.
    """
    query = f"DELETE FROM {self.settings.name}"
    params = []

    if expired_only:
        # a row is VALID while ``expiry_column > now`` (see :py:meth:`contains`),
        # so expired rows are those with ``expiry_column <= now`` — the previous
        # ``> ?`` comparison deleted the non-expired rows instead
        query += f"\nWHERE {self.expiry_column} <= ?"
        params.append(datetime.now().isoformat())

    async with self.connection.execute(query, params) as cur:
        count = cur.rowcount
    return count
async def delete_response(self, request: RepositoryRequestType[K]) -> bool:
    """
    Delete the given ``request`` from this repository if it exists.
    Returns True if deleted in the repository and False if ``request`` was not found in the repository.
    """
    key = self.get_key_from_request(request)
    if not key:
        # nothing to delete for a request no key can be extracted from —
        # previously this fell through and crashed when executing the query
        return False

    and_sep = "\n\tAND "
    query = "\n".join((
        f"DELETE FROM {self.settings.name}",
        "WHERE " + and_sep.join(f"{column} = ?" for column in self._primary_key_columns),
    ))
    async with self.connection.execute(query, key) as cur:
        count = cur.rowcount
    return count > 0
@classmethod
def connect(cls, value: Any, **kwargs) -> Self:
    """Connect to the backend, interpreting the generic ``value`` as a DB file path."""
    return cls.connect_with_path(path=value, **kwargs)

@classmethod
def connect_with_in_memory_db(cls, **kwargs) -> Self:
    """Connect with an in-memory SQLite DB and return an instantiated :py:class:`SQLiteResponseCache`"""
    def connector() -> aiosqlite.Connection:
        return aiosqlite.connect(database="file::memory:?cache=shared", uri=True)

    return cls(cache_name="__IN_MEMORY__", connector=connector, **cls._clean_kwargs(kwargs))

@classmethod
def connect_with_temp_db(cls, name: str = f"{PROGRAM_NAME.lower()}_db.tmp", **kwargs) -> Self:
    """Connect with a temporary SQLite DB and return an instantiated :py:class:`SQLiteResponseCache`"""
    db_path = cls._get_sqlite_path(join(gettempdir(), name))

    def connector() -> aiosqlite.Connection:
        return aiosqlite.connect(database=db_path)

    return cls(cache_name=name, connector=connector, **cls._clean_kwargs(kwargs))
class CachedResponse(ClientResponse):
    """Emulates :py:class:`ClientResponse` for a response found in a cache backend."""

    def __init__(self, request: ClientRequest, data: str | bytes):
        # build the base response from the originating request's metadata;
        # there is no real network writer/timer for a cached hit
        # noinspection PyTypeChecker,PyProtectedMember
        super().__init__(
            method=request.method,
            url=request.url,
            writer=None,
            continue100=None,
            timer=TimerNoop(),
            request_info=request.request_info,
            traces=[],
            loop=request.loop,
            session=request._session,
        )

        # response status — a cached entry always presents as a successful response
        self.version = request.version
        self.status = 200
        self.reason = "cached"

        # headers — mirrored from the request that produced this cache entry
        self._headers = CIMultiDictProxy(request.headers)
        self._raw_headers = ()

        # pre-feed the cached body into a stream reader so the usual
        # read()/text()/json() APIs work unchanged on this response
        self.content = StreamReader(loop=self._loop)
        payload = data.encode() if isinstance(data, str) else data
        self.content.feed_data(payload)
        self.content.feed_eof()
async def __aenter__(self) -> Self:
    self.cache = await self.cache.__aenter__()
    await super().__aenter__()
    return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
    # close the session before the cache so no in-flight request can touch a closed cache
    await super().__aexit__(exc_type, exc_val, exc_tb)
    await self.cache.__aexit__(exc_type, exc_val, exc_tb)

@contextlib.asynccontextmanager
async def request(self, method: str, url: str | URL, persist: bool = True, **kwargs):
    """
    Perform HTTP request.

    :param method: HTTP request method (such as GET, POST, PUT, etc.)
    :param url: The URL to perform the request on.
    :param persist: Whether to persist responses returned from sending network requests
        i.e. non-cached responses.
    :return: Either the :py:class:`CachedResponse` if a response was found in the cache,
        or the :py:class:`ClientResponse` if the request was sent.
    """
    try:
        url = self._build_url(url)
    except ValueError as e:
        raise InvalidURL(url) from e

    # session headers are defaults only: headers passed for this specific call must win.
    # With dict union the RIGHT operand wins, so per-request headers go on the right —
    # the previous ordering let session defaults clobber per-request headers.
    kwargs["headers"] = dict(self.headers) | dict(kwargs.get("headers", {}))
    req = ClientRequest(
        method=method,
        url=url,
        loop=self._loop,
        response_class=self._response_class,
        session=self,
        trust_env=self.trust_env,
        **kwargs,
    )

    repository = self.cache.get_repository_from_requests(req.request_info)
    response = await self._get_cached_response(req, repository=repository)
    if response is None:
        response = await super().request(method=method, url=url, **kwargs)

    yield response

    if persist and repository is not None and not isinstance(response, CachedResponse):
        await repository.save_response(response)
isinstance(data, str | bytes): repository = self.cache.get_repository_from_url(request.url) - cached_data = repository.serialize(cached_data) - - response = Response() - response.encoding = "utf-8" - response._content = cached_data.encode(response.encoding) - response.status_code = 200 - response.url = request.url - response.request = request + data = repository.serialize(data) - return response + return CachedResponse(request=request, data=data) diff --git a/musify/api/exception.py b/musify/api/exception.py index 7cbef9b6..094bfd84 100644 --- a/musify/api/exception.py +++ b/musify/api/exception.py @@ -1,7 +1,7 @@ """ Exceptions relating to API operations. """ -from requests import Response +from aiohttp import ClientResponse from musify.exception import MusifyError @@ -11,13 +11,13 @@ class APIError(MusifyError): Exception raised for API errors. :param message: Explanation of the error. - :param response: The :py:class:`Response` related to the error. + :param response: The :py:class:`ClientResponse` related to the error. """ - def __init__(self, message: str | None = None, response: Response | None = None): + def __init__(self, message: str | None = None, response: ClientResponse | None = None): self.message = message self.response = response - formatted = f"Status code: {response.status_code} | {message}" if response else message + formatted = f"Status code: {response.status} | {message}" if response else message super().__init__(formatted) diff --git a/musify/api/request.py b/musify/api/request.py index 742a50d0..7e463e67 100644 --- a/musify/api/request.py +++ b/musify/api/request.py @@ -1,26 +1,29 @@ """ All operations relating to handling of requests to an API. 
""" +import contextlib import json import logging -from collections.abc import Mapping, Iterable +from collections.abc import Mapping, Callable from datetime import datetime, timedelta from http import HTTPStatus from time import sleep -from typing import Any +from typing import Any, AsyncContextManager, Self +from urllib.parse import unquote -import requests -from requests import Response, Session +import aiohttp +from aiohttp import ClientResponse, ClientSession +from yarl import URL from musify.api.authorise import APIAuthoriser -from musify.api.cache.backend.base import ResponseCache +from musify.api.cache.backend import ResponseCache from musify.api.cache.session import CachedSession -from musify.api.exception import APIError +from musify.api.exception import APIError, RequestError from musify.log.logger import MusifyLogger from musify.utils import clean_kwargs -class RequestHandler: +class RequestHandler(AsyncContextManager): """ Generic API request handler using cached responses for GET requests only. Caches GET responses for a maximum of 4 weeks by default. @@ -28,14 +31,13 @@ class RequestHandler: See :py:class:`APIAuthoriser` for more info on which params to pass to authorise requests. :param authoriser: The authoriser to use when authorising requests to the API. - :param cache: When given, set up a :py:class:`CachedSession` and attempt to use the cache - for certain request types before calling the API. + :param session: The session to use when making requests. 
""" - __slots__ = ("logger", "authoriser", "cache", "session", "backoff_start", "backoff_factor", "backoff_count") + __slots__ = ("logger", "_connector", "_session", "authoriser", "backoff_start", "backoff_factor", "backoff_count") @property - def backoff_final(self) -> int: + def backoff_final(self) -> float: """ The maximum wait time to retry a request in seconds until giving up when applying backoff to failed requests @@ -50,17 +52,36 @@ def timeout(self) -> int: """ return sum(self.backoff_start * self.backoff_factor ** i for i in range(self.backoff_count + 1)) - def __init__(self, authoriser: APIAuthoriser, cache: ResponseCache | None = None): + @property + def closed(self): + """Is the stored client session closed.""" + return self._session is None or self._session.closed + + @property + def session(self) -> ClientSession: + """The :py:class:`ClientSession` object if exist and it is open.""" + if not self.closed: + return self._session + + @classmethod + def create(cls, authoriser: APIAuthoriser | None = None, cache: ResponseCache | None = None, **session_kwargs): + def connector() -> ClientSession: + if cache is not None: + return CachedSession(cache=cache, **session_kwargs) + return ClientSession(**session_kwargs) + + return cls(connector=connector, authoriser=authoriser) + + def __init__(self, connector: Callable[[], ClientSession], authoriser: APIAuthoriser | None = None): # noinspection PyTypeChecker #: The :py:class:`MusifyLogger` for this object self.logger: MusifyLogger = logging.getLogger(__name__) + self._connector = connector + self._session: ClientSession | CachedSession | None = None + #: The :py:class:`APIAuthoriser` object self.authoriser = authoriser - #: The cache to use when attempting to return a cached response. 
- self.cache = cache - #: The :py:class:`Session` object - self.session = Session() if cache is None else CachedSession(cache=cache) #: The initial backoff time for failed requests self.backoff_start = 0.5 @@ -69,7 +90,20 @@ def __init__(self, authoriser: APIAuthoriser, cache: ResponseCache | None = None #: The maximum number of request attempts to make before giving up and raising an exception self.backoff_count = 10 - def authorise(self, force_load: bool = False, force_new: bool = False) -> dict[str, str]: + async def __aenter__(self) -> Self: + if self.closed: + self._session = self._connector() + + await self.session.__aenter__() + await self.authorise() + + return self + + async def __aexit__(self, __exc_type, __exc_value, __traceback) -> None: + await self.session.__aexit__(__exc_type, __exc_value, __traceback) + self._session = None + + async def authorise(self, force_load: bool = False, force_new: bool = False) -> dict[str, str]: """ Method for API authorisation which tests/refreshes/reauthorises as needed. @@ -79,106 +113,139 @@ def authorise(self, force_load: bool = False, force_new: bool = False) -> dict[s :return: Headers for request authorisation. :raise APIError: If the token cannot be validated. """ - headers = self.authoriser(force_load=force_load, force_new=force_new) - self.session.headers.update(headers) + if self.closed: + raise RequestError("Session is closed. Enter the API context to start a new session.") + + headers = {} + if self.authoriser is not None: + headers = await self.authoriser.authorise(force_load=force_load, force_new=force_new) + self.session.headers.update(headers) + return headers - def close(self) -> None: + async def close(self) -> None: """Close the current session. 
No more requests will be possible once this has been called.""" - self.session.close() + await self.session.close() - def request(self, method: str, url: str, *args, **kwargs) -> dict[str, Any]: + async def request(self, method: str, url: str, **kwargs) -> dict[str, Any]: """ Generic method for handling API requests with back-off on failed requests. See :py:func:`request` for more arguments. - :param method: method for the new :class:`Request` object: + :param method: method for the request: ``GET``, ``OPTIONS``, ``HEAD``, ``POST``, ``PUT``, ``PATCH``, or ``DELETE``. - :param url: URL for the new :class:`Request` object. + :param url: URL to call. :return: The JSON formatted response or, if JSON formatting not possible, the text response. :raise APIError: On any logic breaking error/response. """ - kwargs.pop("headers", None) - response = self._request(method=method, url=url, *args, **kwargs) backoff = self.backoff_start - while response is None or response.status_code >= 400: # error response code received - waited = False - if response is not None: - self._log_response(response=response, method=method, url=url) - self._handle_unexpected_response(response=response) - waited = self._handle_wait_time(response=response) - - if not waited and backoff < self.backoff_final: # exponential backoff - self.logger.warning(f"Request failed: retrying in {backoff} seconds...") - sleep(backoff) - backoff *= self.backoff_factor - elif not waited: # max backoff exceeded - raise APIError("Max retries exceeded") - - response = self._request(method=method, url=url, *args, **kwargs) - - return self._response_as_json(response) - - def _request( + while True: + async with self._request(method=method, url=url, **kwargs) as response: + if response is not None and response.status < 400: + data = await self._response_as_json(response) + break + + waited = None + if response is not None: + await self._log_response(response=response, method=method, url=url) + await 
self._handle_unexpected_response(response=response) + waited = await self._handle_wait_time(response=response) + + if not waited and backoff < self.backoff_final: # exponential backoff + self.logger.warning(f"Request failed: retrying in {backoff} seconds...") + sleep(backoff) + backoff *= self.backoff_factor + elif waited is False: # max backoff exceeded + raise APIError("Max retries exceeded") + elif waited is None: # max backoff exceeded + raise APIError("No response received") + + return data + + @contextlib.asynccontextmanager + async def _request( self, method: str, - url: str, - log_pad: int = 43, - log_extra: Iterable[str] = (), - *args, + url: str | URL, + log_message: str | list[str] = None, **kwargs - ) -> Response | None: + ) -> ClientResponse | None: """Handle logging a request, send the request, and return the response""" - log = [f"{method.upper():<7}: {url:<{log_pad}}"] - if log_extra: - log.extend(log_extra) - if len(args) > 0: - log.append(f"Args: ({', '.join(args)})") - if len(kwargs) > 0: - log.extend(f"{k.title()}: {v}" for k, v in kwargs.items()) + if isinstance(log_message, str): + log_message = [log_message] + elif log_message is None: + log_message = [] + if isinstance(self.session, CachedSession): - log.append("Cached Request") + log_message.append("Cached Request") + self.log(method=method, url=url, message=log_message, **kwargs) if not isinstance(self.session, CachedSession): - clean_kwargs(self.session.request, kwargs) + clean_kwargs(aiohttp.request, kwargs) if "headers" in kwargs: kwargs["headers"].update(self.session.headers) - self.logger.debug(" | ".join(log)) try: - return self.session.request(method=method.upper(), url=url, *args, **kwargs) - except requests.exceptions.ConnectionError as ex: + async with self.session.request(method=method.upper(), url=url, **kwargs) as response: + yield response + except aiohttp.ClientError as ex: self.logger.warning(str(ex)) - return + yield + + def log( + self, method: str, url: str | URL, 
message: str | list = None, level: int = logging.DEBUG, **kwargs + ) -> None: + """Format and log a request or request adjacent message to the given ``level``.""" + log: list[Any] = [] + + url = URL(url) + if url.query: + log.extend(f"{k}: {unquote(v):<4}" for k, v in sorted(url.query.items())) + if kwargs.get("params"): + log.extend(f"{k}: {v:<4}" for k, v in sorted(kwargs.pop("params").items())) + if kwargs.get("json"): + log.extend(f"{k}: {str(v):<4}" for k, v in sorted(kwargs.pop("json").items())) + if len(kwargs) > 0: + log.extend(f"{k.title()}: {str(v):<4}" for k, v in kwargs.items() if v) + if message: + log.append(message) if isinstance(message, str) else log.extend(message) - def _log_response(self, response: Response, method: str, url: str) -> None: + url = str(url.with_query(None)) + url_pad_map = [30, 40, 70, 100] + url_pad = next((pad for pad in url_pad_map if len(url) < pad), url_pad_map[-1]) + + self.logger.log( + level=level, msg=f"{method.upper():<7}: {url:<{url_pad}} | {" | ".join(str(part) for part in log)}" + ) + + async def _log_response(self, response: ClientResponse, method: str, url: str) -> None: """Log the method, URL, response text, and response headers.""" response_headers = response.headers if isinstance(response.headers, Mapping): # format headers if JSON response_headers = json.dumps(dict(response.headers), indent=2) self.logger.warning( - f"\33[91m{method.upper():<7}: {url} | Code: {response.status_code} | " + f"\33[91m{method.upper():<7}: {url} | Code: {response.status} | " f"Response text and headers follow:\n" - f"Response text:\n{response.text}\n" + f"Response text:\n{await response.text()}\n" f"Headers:\n{response_headers}\33[0m" ) - def _handle_unexpected_response(self, response: Response) -> bool: + async def _handle_unexpected_response(self, response: ClientResponse) -> bool: """Handle bad response by extracting message and handling status codes that should raise an exception.""" - message = 
self._response_as_json(response).get("error", {}).get("message") + message = (await self._response_as_json(response)).get("error", {}).get("message") error_message_found = message is not None if not error_message_found: - status = HTTPStatus(response.status_code) + status = HTTPStatus(response.status) message = f"{status.phrase} | {status.description}" - if 400 <= response.status_code < 408: + if 400 <= response.status < 408: raise APIError(message, response=response) return error_message_found - def _handle_wait_time(self, response: Response) -> bool: + async def _handle_wait_time(self, response: ClientResponse) -> bool: """Handle when a wait time is included in the response headers.""" if "retry-after" not in response.headers: return False @@ -195,48 +262,49 @@ def _handle_wait_time(self, response: Response) -> bool: return True @staticmethod - def _response_as_json(response: Response) -> dict[str, Any]: + async def _response_as_json(response: ClientResponse) -> dict[str, Any]: """Format the response to JSON and handle any errors""" try: - return response.json() - except json.decoder.JSONDecodeError: + data = await response.json() + return data if isinstance(data, dict) else {} + except (aiohttp.ContentTypeError, json.decoder.JSONDecodeError): return {} - def get(self, url: str, **kwargs) -> dict[str, Any]: + async def get(self, url: str, **kwargs) -> dict[str, Any]: """Sends a GET request.""" kwargs.pop("method", None) - return self.request("get", url=url, **kwargs) + return await self.request("get", url=url, **kwargs) - def post(self, url: str, **kwargs) -> dict[str, Any]: + async def post(self, url: str, **kwargs) -> dict[str, Any]: """Sends a POST request.""" kwargs.pop("method", None) - return self.request("post", url=url, **kwargs) + return await self.request("post", url=url, **kwargs) - def put(self, url: str, **kwargs) -> dict[str, Any]: + async def put(self, url: str, **kwargs) -> dict[str, Any]: """Sends a PUT request.""" kwargs.pop("method", None) - 
return self.request("put", url=url, **kwargs) + return await self.request("put", url=url, **kwargs) - def delete(self, url: str, **kwargs) -> dict[str, Any]: + async def delete(self, url: str, **kwargs) -> dict[str, Any]: """Sends a DELETE request.""" kwargs.pop("method", None) - return self.request("delete", url, **kwargs) + return await self.request("delete", url, **kwargs) - def options(self, url: str, **kwargs) -> dict[str, Any]: + async def options(self, url: str, **kwargs) -> dict[str, Any]: """Sends an OPTIONS request.""" kwargs.pop("method", None) - return self.request("options", url=url, **kwargs) + return await self.request("options", url=url, **kwargs) - def head(self, url: str, **kwargs) -> dict[str, Any]: + async def head(self, url: str, **kwargs) -> dict[str, Any]: """Sends a HEAD request.""" kwargs.pop("method", None) kwargs.setdefault("allow_redirects", False) - return self.request("head", url=url, **kwargs) + return await self.request("head", url=url, **kwargs) - def patch(self, url: str, **kwargs) -> dict[str, Any]: + async def patch(self, url: str, **kwargs) -> dict[str, Any]: """Sends a PATCH request.""" kwargs.pop("method", None) - return self.request("patch", url=url, **kwargs) + return await self.request("patch", url=url, **kwargs) def __copy__(self): """Do not copy handler""" @@ -245,10 +313,3 @@ def __copy__(self): def __deepcopy__(self, _: dict = None): """Do not copy handler""" return self - - def __enter__(self): - self.authorise() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() diff --git a/musify/core/printer.py b/musify/core/printer.py index ae5c8380..83e019ac 100644 --- a/musify/core/printer.py +++ b/musify/core/printer.py @@ -218,7 +218,7 @@ def get_settings(kls: type) -> None: for cls in classes: attributes |= { k: getattr(self, k) for k in cls.__dict__.keys() - if k not in ignore and isinstance(getattr(cls, k), property) + if k not in ignore and isinstance(getattr(cls, k), property) and not 
k.startswith("_") } return attributes diff --git a/musify/libraries/core/collection.py b/musify/libraries/core/collection.py index eeecc6ff..c199092a 100644 --- a/musify/libraries/core/collection.py +++ b/musify/libraries/core/collection.py @@ -27,7 +27,7 @@ class ItemGetterStrategy(ABC): @property @abstractmethod def name(self) -> str: - """The name of the name to assign to this ItemGetter when logging""" + """The name to assign to this ItemGetter when logging""" raise NotImplementedError @abstractmethod @@ -47,6 +47,7 @@ def get_item[T](self, collection: MusifyCollection[T]) -> T: class NameGetter(ItemGetterStrategy): """Get an item via its name for a :py:class:`MusifyCollection`""" + @property def name(self) -> str: return "name" @@ -56,6 +57,7 @@ def get_value_from_item(self, item: MusifyObject) -> str: class PathGetter(ItemGetterStrategy): """Get an item via its path for a :py:class:`MusifyCollection`""" + @property def name(self) -> str: return "path" @@ -65,6 +67,7 @@ def get_value_from_item(self, item: File) -> str: class RemoteIDGetter(ItemGetterStrategy): """Get an item via its remote ID for a :py:class:`MusifyCollection`""" + @property def name(self) -> str: return "remote ID" @@ -74,6 +77,7 @@ def get_value_from_item(self, item: RemoteResponse) -> str: class RemoteURIGetter(ItemGetterStrategy): """Get an item via its remote URI for a :py:class:`MusifyCollection`""" + @property def name(self) -> str: return "URI" @@ -83,6 +87,7 @@ def get_value_from_item(self, item: MusifyItem | RemoteResponse) -> str: class RemoteURLAPIGetter(ItemGetterStrategy): """Get an item via its remote API URL for a :py:class:`MusifyCollection`""" + @property def name(self) -> str: return "API URL" @@ -92,6 +97,7 @@ def get_value_from_item(self, item: RemoteObject) -> str: class RemoteURLEXTGetter(ItemGetterStrategy): """Get an item via its remote external URL for a :py:class:`MusifyCollection`""" + @property def name(self) -> str: return "external URL" @@ -327,7 +333,7 @@ def 
__getitem__( caught_exceptions.append(ex) raise MusifyKeyError( - f"Key is invalid. The following errors were thrown: {[str(ex) for ex in caught_exceptions]}" + f"Key is invalid. The following errors were thrown: {", ".join(str(ex) for ex in caught_exceptions)}" ) @staticmethod diff --git a/musify/libraries/local/library/library.py b/musify/libraries/local/library/library.py index a67f5c75..f55b4173 100644 --- a/musify/libraries/local/library/library.py +++ b/musify/libraries/local/library/library.py @@ -2,7 +2,7 @@ The core, basic library implementation which is just a simple set of folders. """ import itertools -from collections.abc import Collection, Mapping +from collections.abc import Collection, Mapping, Iterable from concurrent.futures import ThreadPoolExecutor from functools import reduce from os.path import splitext, join, exists, basename, dirname @@ -394,7 +394,9 @@ def save_playlists(self, dry_run: bool = True) -> dict[str, Result]: ## Backup/restore ########################################################################### def restore_tracks( - self, backup: Mapping[str, Mapping[str, Any]], tags: UnitIterable[LocalTrackField] = LocalTrackField.ALL + self, + backup: Iterable[Mapping[str, Any]] | Mapping[str, Mapping[str, Any]], + tags: UnitIterable[LocalTrackField] = LocalTrackField.ALL ) -> int: """ Restore track tags from a backup to loaded track objects. This does not save the updated tags. 
@@ -404,7 +406,11 @@ def restore_tracks( :return: The number of tracks restored """ tag_names = set(LocalTrackField.to_tags(tags)) - backup = {path.casefold(): track_map for path, track_map in backup.items()} + if isinstance(backup, Mapping): + backup = {path.casefold(): track_map for path, track_map in backup.items()} + else: + backup = {track_map["path"].casefold(): track_map for track_map in backup} + backup: Mapping[str, Mapping[str, Any]] count = 0 for track in self.tracks: @@ -413,8 +419,8 @@ def restore_tracks( continue for tag in tag_names: - if tag in track.__dict__: - track[tag] = track_map.get(tag) + if tag in track_map: + track[tag] = track_map[tag] count += 1 return count diff --git a/musify/libraries/local/library/musicbee.py b/musify/libraries/local/library/musicbee.py index a6bbb23d..54b72730 100644 --- a/musify/libraries/local/library/musicbee.py +++ b/musify/libraries/local/library/musicbee.py @@ -4,11 +4,11 @@ """ import hashlib import re -import urllib.parse from collections.abc import Iterable, Mapping, Sequence, Collection, Iterator from datetime import datetime from os.path import join, exists, normpath from typing import Any +from urllib.parse import quote, unquote from musify.file.base import File from musify.file.exception import FileDoesNotExistError @@ -376,14 +376,14 @@ def from_xml_timestamp(cls, timestamp_str: str | None) -> datetime | None: @staticmethod def to_xml_path(path: str) -> str: """Convert a standard system path to a file path as found in the MusicBee XML library file""" - return f"file://localhost/{urllib.parse.quote(path.replace('\\', '/'), safe=':/!(),;@[]+')}"\ + return f"file://localhost/{quote(path.replace('\\', '/'), safe=':/!(),;@[]+')}"\ .replace("%26", "&")\ .replace("%27", "'") @staticmethod def from_xml_path(path: str) -> str: """Clean the file paths as found in the MusicBee XML library file to a standard system path""" - return normpath(urllib.parse.unquote(path.removeprefix("file://localhost/"))) + return 
normpath(unquote(path.removeprefix("file://localhost/"))) def _iter_elements(self) -> Iterator[Element]: for event, element in self._iterparse: diff --git a/musify/libraries/local/playlist/m3u.py b/musify/libraries/local/playlist/m3u.py index cf984173..aca6f2c3 100644 --- a/musify/libraries/local/playlist/m3u.py +++ b/musify/libraries/local/playlist/m3u.py @@ -53,7 +53,7 @@ class M3U(LocalPlaylist[FilterDefinedList[str | File]]): For more info on this, see :py:class:`LocalTrack`. """ - __slots__ = ("_description",) + __slots__ = ("_original_paths", "_description") valid_extensions = frozenset({".m3u"}) @@ -80,14 +80,14 @@ def __init__( if exists(path): # load from file with open(path, "r", encoding="utf-8") as file: - paths = path_mapper.map_many([line.strip() for line in file], check_existence=True) + self._original_paths = path_mapper.map_many([line.strip() for line in file], check_existence=True) else: # generating a new M3U - paths = [track.path for track in tracks] + self._original_paths = [track.path for track in tracks] self._description = None super().__init__( path=path, - matcher=FilterDefinedList(values=[path.casefold() for path in paths]), + matcher=FilterDefinedList(values=[path.casefold() for path in self._original_paths]), path_mapper=path_mapper, remote_wrangler=remote_wrangler ) @@ -102,7 +102,7 @@ def load(self, tracks: Collection[LocalTrack] = ()) -> list[LocalTrack]: if tracks: # match paths from given tracks using the matcher self._match(tracks) else: # use the paths in the matcher to load tracks from scratch - self.tracks = [self._load_track(path) for path in self.matcher.values if path is not None] + self.tracks = [self._load_track(path) for path in self._original_paths if path is not None] self._limit(ignore=self.matcher.values) self._sort() @@ -118,7 +118,8 @@ def save(self, dry_run: bool = True, *_, **__) -> SyncResultM3U: :return: The results of the sync as a :py:class:`SyncResultM3U` object. 
""" start_paths = {path.casefold() for path in self.path_mapper.unmap_many(self._original, check_existence=False)} - os.makedirs(dirname(self.path), exist_ok=True) + if dirname(self.path): + os.makedirs(dirname(self.path), exist_ok=True) if not dry_run: with open(self.path, "w", encoding="utf-8") as file: diff --git a/musify/libraries/local/playlist/xautopf.py b/musify/libraries/local/playlist/xautopf.py index 8e2f32d6..856cb36b 100644 --- a/musify/libraries/local/playlist/xautopf.py +++ b/musify/libraries/local/playlist/xautopf.py @@ -74,7 +74,8 @@ class XAutoPF(LocalPlaylist[AutoMatcher]): :param path: Absolute path of the playlist. :param tracks: Optional. Available Tracks to search through for matches. - If none are provided, no tracks will be loaded initially + If none are provided, no tracks will be loaded initially. In order to load the playlist in this case, + you will need to call :py:meth:`load` and provide some loaded tracks. :param path_mapper: Optionally, provide a :py:class:`PathMapper` for paths stored in the playlist file. Useful if the playlist file contains relative paths and/or paths for other systems that need to be mapped to absolute, system-specific paths to be loaded and back again when saved. 
diff --git a/musify/libraries/local/track/track.py b/musify/libraries/local/track/track.py index 37673003..05fcad45 100644 --- a/musify/libraries/local/track/track.py +++ b/musify/libraries/local/track/track.py @@ -150,7 +150,9 @@ def date(self): return datetime.date(self._year, self._month, self._day) @date.setter - def date(self, value: datetime.date | None): + def date(self, value: str | datetime.date | None): + if isinstance(value, str): + value = datetime.datetime.strptime(value, "%Y-%m-%d") self._year = value.year if value else None self._month = value.month if value else None self._day = value.day if value else None diff --git a/musify/libraries/remote/core/api.py b/musify/libraries/remote/core/api.py index 7c65644d..8732029c 100644 --- a/musify/libraries/remote/core/api.py +++ b/musify/libraries/remote/core/api.py @@ -6,7 +6,8 @@ import logging from abc import ABC, abstractmethod from collections.abc import Collection, MutableMapping, Mapping, Sequence -from typing import Any, Self +from typing import Any, Self, AsyncContextManager + from musify.api.authorise import APIAuthoriser from musify.api.cache.backend.base import ResponseCache @@ -20,7 +21,7 @@ from musify.utils import align_string, to_collection -class RemoteAPI(ABC): +class RemoteAPI(AsyncContextManager, ABC): """ Collection of endpoints for a remote API. 
See :py:class:`RequestHandler` and :py:class:`APIAuthoriser` @@ -76,22 +77,32 @@ def __init__(self, authoriser: APIAuthoriser, wrangler: RemoteDataWrangler, cach #: The :py:class:`MusifyLogger` for this object self.logger: MusifyLogger = logging.getLogger(__name__) - self._setup_cache(cache) - #: A :py:class:`RemoteDataWrangler` object for processing URIs self.wrangler = wrangler + #: The :py:class:`RequestHandler` for handling authorised requests to the API - self.handler = RequestHandler(authoriser=authoriser, cache=cache) + self.handler = RequestHandler.create(authoriser=authoriser, cache=cache) #: Stores the loaded user data for the currently authorised user self.user_data: dict[str, Any] = {} + async def __aenter__(self) -> Self: + await self.handler.__aenter__() + + await self._setup_cache() + await self.load_user_data() + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.handler.__aexit__(exc_type, exc_val, exc_tb) + @abstractmethod - def _setup_cache(self, cache: ResponseCache) -> None: - """Set up the repositories and repository getter on the given ``cache``.""" + async def _setup_cache(self) -> None: + """Set up the repositories and repository getter on the self.handler.session's cache.""" raise NotImplementedError - def authorise(self, force_load: bool = False, force_new: bool = False) -> Self: + async def authorise(self, force_load: bool = False, force_new: bool = False) -> Self: """ Main method for authorisation, tests/refreshes/reauthorises as needed @@ -101,19 +112,19 @@ def authorise(self, force_load: bool = False, force_new: bool = False) -> Self: :return: Self. :raise APIError: If the token cannot be validated. """ - self.handler.authorise(force_load=force_load, force_new=force_new) + await self.handler.authorise(force_load=force_load, force_new=force_new) return self - def close(self) -> None: + async def close(self) -> None: """Close the current session. 
No more requests will be possible once this has been called.""" - self.handler.close() + await self.handler.close() ########################################################################### ## Misc helpers ########################################################################### - def load_user_data(self) -> None: + async def load_user_data(self) -> None: """Load and store user data in this API object for the currently authorised user""" - self.user_data = self.get_self() + self.user_data = await self.get_self() @staticmethod def _merge_results_to_input_mapping( @@ -213,7 +224,7 @@ def print_item( ) @abstractmethod - def print_collection( + async def print_collection( self, value: str | Mapping[str, Any] | RemoteResponse | None = None, kind: RemoteIDType | None = None, @@ -237,7 +248,7 @@ def print_collection( raise NotImplementedError @abstractmethod - def get_playlist_url(self, playlist: str | Mapping[str, Any] | RemoteResponse) -> str: + async def get_playlist_url(self, playlist: str | Mapping[str, Any] | RemoteResponse) -> str: """ Determine the type of the given ``playlist`` and return its API URL. If type cannot be determined, attempt to find the playlist in the @@ -254,7 +265,7 @@ def get_playlist_url(self, playlist: str | Mapping[str, Any] | RemoteResponse) - ## Core - GET endpoints ########################################################################### @abstractmethod - def get_self(self, update_user_data: bool = True) -> dict[str, Any]: + async def get_self(self, update_user_data: bool = True) -> dict[str, Any]: """ ``GET`` - Get API response for information on current user @@ -262,7 +273,7 @@ def get_self(self, update_user_data: bool = True) -> dict[str, Any]: raise NotImplementedError @abstractmethod - def query(self, query: str, kind: RemoteObjectType, limit: int = 10) -> list[dict[str, Any]]: + async def query(self, query: str, kind: RemoteObjectType, limit: int = 10) -> list[dict[str, Any]]: """ ``GET`` - Query for items. 
Modify result types returned with kind parameter @@ -277,7 +288,7 @@ def query(self, query: str, kind: RemoteObjectType, limit: int = 10) -> list[dic ## Item - GET endpoints ########################################################################### @abstractmethod - def extend_items( + async def extend_items( self, response: MutableMapping[str, Any] | RemoteResponse, kind: RemoteObjectType | str | None = None, @@ -301,7 +312,7 @@ def extend_items( raise NotImplementedError @abstractmethod - def get_items( + async def get_items( self, values: APIInputValue, kind: RemoteObjectType | None = None, @@ -338,7 +349,7 @@ def get_items( raise NotImplementedError @abstractmethod - def get_tracks(self, values: APIInputValue, limit: int = 50, *args, **kwargs) -> list[dict[str, Any]]: + async def get_tracks(self, values: APIInputValue, limit: int = 50, *args, **kwargs) -> list[dict[str, Any]]: """ Wrapper for :py:meth:`get_items` which only returns Track type responses. See :py:meth:`get_items` for more info. @@ -346,7 +357,7 @@ def get_tracks(self, values: APIInputValue, limit: int = 50, *args, **kwargs) -> raise NotImplementedError @abstractmethod - def get_user_items( + async def get_user_items( self, user: str | None = None, kind: RemoteObjectType = RemoteObjectType.PLAYLIST, limit: int = 50, ) -> list[dict[str, Any]]: """ @@ -366,7 +377,7 @@ def get_user_items( ## Collection - POST endpoints ########################################################################### @abstractmethod - def create_playlist(self, name: str, *args, **kwargs) -> str: + async def create_playlist(self, name: str, *args, **kwargs) -> str: """ ``POST`` - Create an empty playlist for the current user with the given name. 
@@ -376,7 +387,7 @@ def create_playlist(self, name: str, *args, **kwargs) -> str: raise NotImplementedError @abstractmethod - def add_to_playlist( + async def add_to_playlist( self, playlist: str | Mapping[str, Any] | RemoteResponse, items: Collection[str], @@ -406,7 +417,7 @@ def add_to_playlist( ## Collection - DELETE endpoints ########################################################################### @abstractmethod - def delete_playlist(self, playlist: str | Mapping[str, Any] | RemoteResponse) -> str: + async def delete_playlist(self, playlist: str | Mapping[str, Any] | RemoteResponse) -> str: """ ``DELETE`` - Unfollow/delete a given playlist. WARNING: This function will destructively modify your remote playlists. @@ -421,7 +432,7 @@ def delete_playlist(self, playlist: str | Mapping[str, Any] | RemoteResponse) -> raise NotImplementedError @abstractmethod - def clear_from_playlist( + async def clear_from_playlist( self, playlist: str | Mapping[str, Any] | RemoteResponse, items: Collection[str] | None = None, @@ -445,10 +456,3 @@ def clear_from_playlist( are not all tracks or IDs. 
""" raise NotImplementedError - - def __enter__(self): - self.handler.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.handler.__exit__(exc_type, exc_val, exc_tb) diff --git a/musify/libraries/remote/core/base.py b/musify/libraries/remote/core/base.py index 77a35cf8..092ded7b 100644 --- a/musify/libraries/remote/core/base.py +++ b/musify/libraries/remote/core/base.py @@ -5,7 +5,7 @@ """ from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import Any, Self +from typing import Any, Self, AsyncContextManager from musify.api.exception import APIError from musify.core.base import MusifyItem @@ -13,7 +13,7 @@ from musify.libraries.remote.core.api import RemoteAPI -class RemoteObject[T: (RemoteAPI | None)](RemoteResponse, ABC): +class RemoteObject[T: (RemoteAPI | None)](RemoteResponse, AsyncContextManager, ABC): """ Generic base class for remote objects. Extracts key data from a remote API JSON response. @@ -24,8 +24,6 @@ class RemoteObject[T: (RemoteAPI | None)](RemoteResponse, ABC): __slots__ = ("_response", "api") __attributes_ignore__ = ("response", "api") - _url_pad = 71 - @property @abstractmethod def uri(self) -> str: @@ -65,6 +63,13 @@ def __init__(self, response: dict[str, Any], api: T = None, skip_checks: bool = self._check_type() self.refresh(skip_checks=skip_checks) + async def __aenter__(self) -> Self: + await self.api.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.api.__aexit__(exc_type, exc_val, exc_tb) + @abstractmethod def _check_type(self) -> None: """ @@ -85,7 +90,7 @@ def _check_for_api(self) -> None: @classmethod @abstractmethod - def load( + async def load( cls, value: str | Mapping[str, Any] | RemoteResponse, api: RemoteAPI, *args, **kwargs ) -> Self: """ @@ -104,7 +109,7 @@ def load( raise NotImplementedError @abstractmethod - def reload(self, *args, **kwargs) -> None: + async def reload(self, *args, **kwargs) -> None: """ 
Reload this object from the API, calling all required endpoints to get a complete set of data for this item type. diff --git a/musify/libraries/remote/core/library.py b/musify/libraries/remote/core/library.py index ae609f27..ac5ce958 100644 --- a/musify/libraries/remote/core/library.py +++ b/musify/libraries/remote/core/library.py @@ -4,7 +4,7 @@ import logging from abc import ABC, abstractmethod from collections.abc import Collection, Mapping, Iterable -from typing import Any, Literal +from typing import Any, Literal, Self from musify.core.base import MusifyItem from musify.libraries.core.object import Track, Library, Playlist @@ -19,6 +19,14 @@ from musify.processors.filter import FilterDefinedList from musify.utils import align_string, get_max_width +type RestorePlaylistsType = ( + Library | + Collection[Playlist] | + Mapping[str, Iterable[Track]] | + Mapping[str, Iterable[str]] | + Mapping[str, Iterable[Mapping[str, Any]]] +) + class RemoteLibrary[ A: RemoteAPI, PL: RemotePlaylist, TR: RemoteTrack, AL: RemoteAlbum, AR: RemoteArtist @@ -36,6 +44,11 @@ class RemoteLibrary[ __attributes_classes__ = (Library, RemoteCollection) __attributes_ignore__ = ("api", "factory") + @property + def _log_min_width(self) -> int: + max_type_width = max(len(str(enum.name)) for enum in RemoteObjectType.all()) + return len(f"USER'S {self.api.source.upper()} ") + max_type_width + @property def factory(self) -> RemoteObjectFactory[A, PL, TR, AL, AR]: """Stores the key object classes for a remote source.""" @@ -98,7 +111,14 @@ def __init__(self, api: A, playlist_filter: Collection[str] | Filter[str] = ()): self._albums: list[AL] = [] self._artists: list[AR] = [] - def extend(self, __items: Iterable[MusifyItem], allow_duplicates: bool = True) -> None: + async def __aenter__(self) -> Self: + await self.api.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.api.__aexit__(exc_type, exc_val, exc_tb) + + async def extend(self, __items: 
Iterable[MusifyItem], allow_duplicates: bool = True) -> None: self.logger.debug(f"Extend {self.api.source} tracks data: START") load_uris = [] @@ -118,7 +138,7 @@ def extend(self, __items: Iterable[MusifyItem], allow_duplicates: bool = True) - f"with {len(load_uris)} additional tracks \33[0m" ) - load_tracks = self.api.get_tracks(load_uris, features=True) + load_tracks = await self.api.get_tracks(load_uris, features=True) self.items.extend(map(self.factory.track, load_tracks)) self.logger.print(STAT) @@ -126,15 +146,15 @@ def extend(self, __items: Iterable[MusifyItem], allow_duplicates: bool = True) - self.logger.print() self.logger.debug(f"Extend {self.api.source} tracks data: DONE\n") - def load(self) -> None: + async def load(self) -> None: """Loads all data from the remote API for this library and log results.""" self.logger.debug(f"Load {self.api.source} library: START") self.logger.info(f"\33[1;95m ->\33[1;97m Loading {self.api.source} library \33[0m") - self.load_playlists() - self.load_tracks() - self.load_saved_albums() - self.load_saved_artists() + await self.load_playlists() + await self.load_tracks() + await self.load_saved_albums() + await self.load_saved_artists() self.logger.print(STAT) self.log_playlists() @@ -148,22 +168,21 @@ def load(self) -> None: ########################################################################### ## Load - playlists ########################################################################### - def load_playlists(self) -> None: + async def load_playlists(self) -> None: """ Load all playlists from the API that match the filter rules in this library. Also loads all their tracks. WARNING: Overwrites any currently loaded playlists. 
""" self.logger.debug(f"Load {self.api.source} playlists: START") - self.api.load_user_data() - responses = self.api.get_user_items(kind=RemoteObjectType.PLAYLIST) + responses = await self.api.get_user_items(kind=RemoteObjectType.PLAYLIST) responses = self._filter_playlists(responses) self.logger.info( f"\33[1;95m >\33[1;97m Getting {self._get_total_tracks(responses=responses)} " f"{self.api.source} tracks from {len(responses)} playlists \33[0m" ) - self.api.get_items(responses, kind=RemoteObjectType.PLAYLIST) + await self.api.get_items(responses, kind=RemoteObjectType.PLAYLIST) playlists = [ self.factory.playlist(response=r, skip_checks=False) @@ -190,7 +209,7 @@ def _get_total_tracks(self, responses: list[dict[str, Any]]) -> int: raise NotImplementedError def log_playlists(self) -> None: - max_width = get_max_width(self.playlists) + max_width = get_max_width(self.playlists, min_width=self._log_min_width) self.logger.stat(f"\33[1;96m{self.api.source.upper()} PLAYLISTS: \33[0m") for name, playlist in self.playlists.items(): @@ -200,15 +219,14 @@ def log_playlists(self) -> None: ########################################################################### ## Load - tracks ########################################################################### - def load_tracks(self) -> None: + async def load_tracks(self) -> None: """ Load all user's saved tracks from the API. Updates currently loaded tracks in-place or appends if not already loaded. 
""" self.logger.debug(f"Load user's saved {self.api.source} tracks: START") - self.api.load_user_data() - responses = self.api.get_user_items(kind=RemoteObjectType.TRACK) + responses = await self.api.get_user_items(kind=RemoteObjectType.TRACK) for response in self.logger.get_iterator(iterable=responses, desc="Processing tracks", unit="tracks"): track = self.factory.track(response=response, skip_checks=True) @@ -218,13 +236,14 @@ def load_tracks(self) -> None: current = next((item for item in self._tracks if item == track), None) if current is None: self._tracks.append(track) - else: - current._response = track.response - current.refresh(skip_checks=False) + continue + + current._response = track.response + current.refresh(skip_checks=False) self.logger.debug(f"Load user's saved {self.api.source} tracks: DONE") - def enrich_tracks(self, *_, **__) -> None: + async def enrich_tracks(self, *_, **__) -> None: """ Call API to enrich elements of track objects improving metadata coverage. This is an optionally implementable method. Defaults to doing nothing. @@ -237,7 +256,7 @@ def log_tracks(self) -> None: album_tracks = [track.uri for tracks in self.albums for track in tracks] in_albums = len([track for track in self.tracks if track.uri in album_tracks]) - width = get_max_width(self.playlists) + width = get_max_width(self.playlists, min_width=self._log_min_width) self.logger.stat( f"\33[1;96m{"USER'S " + self.api.source.upper() + " TRACKS":<{width}}\33[1;0m |" f"\33[92m{in_playlists:>7} in playlists \33[0m|" @@ -248,15 +267,14 @@ def log_tracks(self) -> None: ########################################################################### ## Load - albums ########################################################################### - def load_saved_albums(self) -> None: + async def load_saved_albums(self) -> None: """ Load all user's saved albums from the API. Updates currently loaded albums in-place or appends if not already loaded. 
""" self.logger.debug(f"Load user's saved {self.api.source} albums: START") - self.api.load_user_data() - responses = self.api.get_user_items(kind=RemoteObjectType.ALBUM) + responses = await self.api.get_user_items(kind=RemoteObjectType.ALBUM) for response in self.logger.get_iterator(iterable=responses, desc="Processing albums", unit="albums"): album = self.factory.album(response=response, skip_checks=True) @@ -273,7 +291,7 @@ def load_saved_albums(self) -> None: self.logger.debug(f"Load user's saved {self.api.source} albums: DONE") - def enrich_saved_albums(self, *_, **__) -> None: + async def enrich_saved_albums(self, *_, **__) -> None: """ Call API to enrich elements of user's saved album objects improving metadata coverage. This is an optionally implementable method. Defaults to doing nothing. @@ -282,7 +300,7 @@ def enrich_saved_albums(self, *_, **__) -> None: def log_albums(self) -> None: """Log stats on currently loaded albums""" - width = get_max_width(self.playlists) + width = get_max_width(self.playlists, min_width=self._log_min_width) self.logger.stat( f"\33[1;96m{"USER'S " + self.api.source.upper() + " ALBUMS":<{width}}\33[1;0m |" f"\33[92m{sum(len(album.tracks) for album in self.albums):>7} album tracks \33[0m|" @@ -293,28 +311,28 @@ def log_albums(self) -> None: ########################################################################### ## Load - artists ########################################################################### - def load_saved_artists(self) -> None: + async def load_saved_artists(self) -> None: """ Load all user's saved artists from the API. Updates currently loaded artists in-place or appends if not already loaded. 
""" self.logger.debug(f"Load user's saved {self.api.source} artists: START") - self.api.load_user_data() - responses = self.api.get_user_items(kind=RemoteObjectType.ARTIST) + responses = await self.api.get_user_items(kind=RemoteObjectType.ARTIST) for response in self.logger.get_iterator(iterable=responses, desc="Processing artists", unit="artists"): artist = self.factory.artist(response=response, skip_checks=True) current = next((item for item in self._artists if item == artist), None) if current is None: self._artists.append(artist) - else: - current._response = artist.response - current.refresh(skip_checks=True) + continue + + current._response = artist.response + current.refresh(skip_checks=True) self.logger.debug(f"Load user's saved {self.api.source} artists: DONE") - def enrich_saved_artists(self, *_, **__) -> None: + async def enrich_saved_artists(self, *_, **__) -> None: """ Call API to enrich elements of user's saved artist objects improving metadata coverage. This is an optionally implementable method. Defaults to doing nothing. @@ -323,7 +341,7 @@ def enrich_saved_artists(self, *_, **__) -> None: def log_artists(self) -> None: """Log stats on currently loaded artists""" - width = get_max_width(self.playlists) + width = get_max_width(self.playlists, min_width=self._log_min_width) self.logger.stat( f"\33[1;96m{"USER'S " + self.api.source.upper() + " ARTISTS":<{width}}\33[1;0m |" f"\33[92m{sum(len(artist.tracks) for artist in self.artists):>7} artist tracks \33[0m|" @@ -341,11 +359,7 @@ def backup_playlists(self) -> dict[str, list[str]]: """ return {name: [track.uri for track in pl] for name, pl in self.playlists.items()} - def restore_playlists( - self, - playlists: Library | Collection[Playlist] | Mapping[str, Iterable[Track]] | Mapping[str, Iterable[str]], - dry_run: bool = True, - ) -> None: + async def restore_playlists(self, playlists: RestorePlaylistsType, dry_run: bool = True) -> None: """ Restore playlists from a backup to loaded playlist objects. 
@@ -366,7 +380,15 @@ def restore_playlists( ): # get URIs from playlists in map values playlists = {name: [item.uri for item in pl] for name, pl in playlists.items()} - elif not isinstance(playlists, Mapping) and isinstance(playlists, Collection): + elif isinstance(playlists, Mapping): + # get URIs from playlists in collection + playlists = { + name: + [t["uri"] if isinstance(t, Mapping) else t for t in tracks["tracks"]] + if isinstance(tracks, Mapping) else tracks + for name, tracks in playlists.items() + } + elif isinstance(playlists, Collection): # get URIs from playlists in collection playlists = {pl.name: [track.uri for track in pl] for pl in playlists} playlists: Mapping[str, Iterable[str]] @@ -375,7 +397,7 @@ def restore_playlists( uri_get = [uri for uri_list in playlists.values() for uri in uri_list if uri not in uri_tracks] if uri_get: - tracks_data = self.api.get_tracks(uri_get, features=False) + tracks_data = await self.api.get_tracks(uri_get, features=False) tracks = list(map(self.factory.track, tracks_data)) uri_tracks |= {track.uri: track for track in tracks} @@ -385,12 +407,12 @@ def restore_playlists( continue if not playlist: # new playlist given, create it on remote first # noinspection PyArgumentList - playlist = self.factory.playlist.create(name=name) + playlist = await self.factory.playlist.create(name=name) playlist._tracks = [uri_tracks.get(uri) for uri in uri_list] self.playlists[name] = playlist - def sync( + async def sync( self, playlists: Library | Mapping[str, Iterable[MusifyItem]] | Collection[Playlist] | None = None, kind: Literal["new", "refresh", "sync"] = "new", @@ -449,8 +471,8 @@ def sync( continue # noinspection PyArgumentList - self.playlists[name] = self.factory.playlist.create(name=name) - results[name] = self.playlists[name].sync(items=pl, kind=kind, reload=reload, dry_run=dry_run) + self.playlists[name] = await self.factory.playlist.create(name=name) + results[name] = await self.playlists[name].sync(items=pl, kind=kind, 
reload=reload, dry_run=dry_run) self.logger.print() self.logger.debug(f"Sync {self.api.source} playlists: DONE\n") diff --git a/musify/libraries/remote/core/object.py b/musify/libraries/remote/core/object.py index 77a23687..eed25ed5 100644 --- a/musify/libraries/remote/core/object.py +++ b/musify/libraries/remote/core/object.py @@ -57,7 +57,7 @@ def _total(self) -> int: @classmethod @abstractmethod - def load( + async def load( cls, value: str | Mapping[str, Any] | Self, api: RemoteAPI, items: Iterable[T] = (), *args, **kwargs ) -> Self: """ @@ -157,7 +157,7 @@ def writeable(self) -> bool: return self.api.user_id == self.owner_id @classmethod - def create(cls, api: RemoteAPI, name: str, public: bool = True, collaborative: bool = False) -> Self: + async def create(cls, api: RemoteAPI, name: str, public: bool = True, collaborative: bool = False) -> Self: """ Create an empty playlist for the current user with the given name and initialise and return a new RemotePlaylist object from this new playlist. @@ -168,20 +168,19 @@ def create(cls, api: RemoteAPI, name: str, public: bool = True, collaborative: b :param collaborative: Set playlist to collaborative i.e. other users may edit the playlist. :return: :py:class:`RemotePlaylist` object for the generated playlist. """ - url = api.create_playlist(name=name, public=public, collaborative=collaborative) - return cls(response=api.handler.get(url), api=api) + url = await api.create_playlist(name=name, public=public, collaborative=collaborative) + return cls(response=await api.handler.get(url), api=api) - def delete(self) -> None: + async def delete(self) -> None: """ Unfollow/delete the current playlist and clear the stored response for this object. WARNING: This function will destructively modify your remote playlists. 
""" self._check_for_api() - - self.api.delete_playlist(self.url) + await self.api.delete_playlist(self.url) self.response.clear() - def sync( + async def sync( self, items: Iterable[MusifyItem] = (), kind: PLAYLIST_SYNC_KINDS = "new", @@ -219,19 +218,19 @@ def sync( # process the remote playlist. when dry_run, mock the results if kind == "refresh": # remove all items from the remote playlist - removed = self.api.clear_from_playlist(self.url) if not dry_run else len(uri_remote) + removed = await self.api.clear_from_playlist(self.url) if not dry_run else len(uri_remote) uri_add = uri_initial uri_unchanged = [] elif kind == "sync": # remove items not present in the current list from the remote playlist uri_clear = [uri for uri in uri_remote if uri not in uri_initial] - removed = self.api.clear_from_playlist(self.url, items=uri_clear) if not dry_run else len(uri_clear) + removed = await self.api.clear_from_playlist(self.url, items=uri_clear) if not dry_run else len(uri_clear) uri_unchanged = [uri for uri in uri_remote if uri in uri_initial] added = len(uri_add) if not dry_run: - added = self.api.add_to_playlist(self.url, items=uri_add, skip_dupes=kind != "refresh") + added = await self.api.add_to_playlist(self.url, items=uri_add, skip_dupes=kind != "refresh") if reload: # reload the current playlist object from remote - self.reload(extend_tracks=True) + await self.reload(extend_tracks=True) return SyncResultRemotePlaylist( start=len(uri_remote), diff --git a/musify/libraries/remote/core/processors/check.py b/musify/libraries/remote/core/processors/check.py index 2e41604a..5c98c1c0 100644 --- a/musify/libraries/remote/core/processors/check.py +++ b/musify/libraries/remote/core/processors/check.py @@ -9,7 +9,7 @@ from collections.abc import Sequence, Collection, Iterator from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any +from typing import Any, Self from musify import PROGRAM_NAME from musify.core.base 
import MusifyItemSettable @@ -135,41 +135,47 @@ def __init__( #: The final list of items skipped by the checker self._final_skipped: list[MusifyItemSettable] = [] - def _check_api(self): + async def __aenter__(self) -> Self: + await self.api.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.api.__aexit__(exc_type, exc_val, exc_tb) + + async def _check_api(self) -> None: """Check if the API token has expired and refresh as necessary""" - if not self.api.handler.authoriser.test_token(): # check if token has expired + if not await self.api.handler.authoriser.test_token(): # check if token has expired self.logger.info_extra("\33[93mAPI token has expired, re-authorising... \33[0m") - self.api.authorise() + await self.api.authorise() - def _create_playlist(self, collection: MusifyCollection[MusifyItemSettable]) -> None: + async def _create_playlist(self, collection: MusifyCollection[MusifyItemSettable]) -> None: """Create a temporary playlist, store its URL for later unfollowing, and add all given URIs.""" - self._check_api() + await self._check_api() uris = [item.uri for item in collection if item.has_uri] if not uris: return - url = self.api.create_playlist(collection.name, public=False) + url = await self.api.create_playlist(collection.name, public=False) self._playlist_name_urls[collection.name] = url self._playlist_name_collection[collection.name] = collection - self.api.add_to_playlist(url, items=uris, skip_dupes=False) + await self.api.add_to_playlist(url, items=uris, skip_dupes=False) - def _delete_playlists(self) -> None: + async def _delete_playlists(self) -> None: """Delete all temporary playlists stored and clear stored playlists and collections""" - self._check_api() + await self._check_api() self.logger.info_extra(f"\33[93mDeleting {len(self._playlist_name_urls)} temporary playlists... 
\33[0m") for url in self._playlist_name_urls.values(): # delete playlists - self.api.delete_playlist(url) + await self.api.delete_playlist(url) self._playlist_name_urls.clear() self._playlist_name_collection.clear() - def __call__(self, *args, **kwargs) -> ItemCheckResult | None: - return self.check(*args, **kwargs) - - def check[T: MusifyItemSettable](self, collections: Collection[MusifyCollection[T]]) -> ItemCheckResult[T] | None: + async def check[T: MusifyItemSettable]( + self, collections: Collection[MusifyCollection[T]] + ) -> ItemCheckResult[T] | None: """ Run the checker for the given ``collections``. @@ -189,36 +195,30 @@ def check[T: MusifyItemSettable](self, collections: Collection[MusifyCollection[ total = len(collections) pages_total = (total // self.interval) + (total % self.interval > 0) - bar = self.logger.get_iterator(total=total, desc="Creating temp playlists", unit="playlists") + bar = self.logger.get_iterator(iter(collections), desc="Creating temp playlists", unit="playlists") self._skip = False self._quit = False - collections_iter = (collection for collection in collections) for page in range(1, pages_total + 1): try: - for count, collection in enumerate(collections_iter, 1): - self._create_playlist(collection=collection) - if tqdm is not None: - bar.update(1) + for count, collection in enumerate(bar, 1): + await self._create_playlist(collection=collection) if count >= self.interval: break - self._pause(page=page, total=pages_total) + await self._pause(page=page, total=pages_total) if not self._quit: # still run if skip is True - self._check_uri() + await self._check_uri() except KeyboardInterrupt: self.logger.error("User triggered exit with KeyboardInterrupt") self._quit = True finally: - self._delete_playlists() + await self._delete_playlists() if self._quit or self._skip: # quit check break - if tqdm is not None: - bar.close() - result = self._finalise() if not self._quit else None self.logger.debug("Checking items: DONE\n") return result 
@@ -250,7 +250,7 @@ def _finalise(self) -> ItemCheckResult: ########################################################################### ## Pause to check items in current temp playlists ########################################################################### - def _pause(self, page: int, total: int) -> None: + async def _pause(self, page: int, total: int) -> None: """ Initial pause after the ``interval`` limit of playlists have been created. @@ -305,8 +305,8 @@ def _pause(self, page: int, total: int) -> None: print() elif self.api.wrangler.validate_id_type(current_input): # print URL/URI/ID result - self._check_api() - self.api.print_collection(current_input) + await self._check_api() + await self.api.print_collection(current_input) elif current_input != "": self.logger.warning("Input not recognised.") @@ -314,15 +314,15 @@ def _pause(self, page: int, total: int) -> None: ########################################################################### ## Match items user has added or removed ########################################################################### - def _check_uri(self) -> None: + async def _check_uri(self) -> None: """Run operations to check that URIs are assigned to all the items in the current list of collections.""" skip_hold = self._skip self._skip = False for name, collection in self._playlist_name_collection.items(): - self.matcher.log_messages([name, f"{len(collection):>6} total items"], pad='>') + self.matcher.log([name, f"{len(collection):>6} total items"], pad='>') while True: - self._match_to_remote(name=name) + await self._match_to_remote(name=name) self._match_to_input(name=name) if not self._remaining: break @@ -330,9 +330,9 @@ def _check_uri(self) -> None: unavailable = tuple(item for item in collection if item.has_uri is False) skipped = tuple(item for item in collection if item.has_uri is None) - self.matcher.log_messages([name, f"{len(self._switched):>6} items switched"], pad='<') - self.matcher.log_messages([name, 
f"{len(unavailable):>6} items unavailable"]) - self.matcher.log_messages([name, f"{len(skipped):>6} items skipped"]) + self.matcher.log([name, f"{len(self._switched):>6} items switched"], pad='<') + self.matcher.log([name, f"{len(unavailable):>6} items unavailable"]) + self.matcher.log([name, f"{len(skipped):>6} items skipped"]) self._final_switched += self._switched self._final_unavailable += unavailable @@ -344,12 +344,12 @@ def _check_uri(self) -> None: self._skip = skip_hold - def _match_to_remote(self, name: str) -> None: + async def _match_to_remote(self, name: str) -> None: """ Check the current temporary playlist given by ``name`` and attempt to match the source list of items to any modifications the user has made. """ - self._check_api() + await self._check_api() self.logger.info( "\33[1;95m ->\33[1;97m Checking for changes to items in " @@ -359,7 +359,7 @@ def _match_to_remote(self, name: str) -> None: source = self._playlist_name_collection[name] source_valid = [item for item in source if item.has_uri] - remote_response = self.api.get_items(self._playlist_name_urls[name], extend=True)[0] + remote_response = next(iter(await self.api.get_items(self._playlist_name_urls[name], extend=True))) remote = self.factory.playlist(response=remote_response).items remote_valid = [item for item in remote if item.has_uri] @@ -369,7 +369,7 @@ def _match_to_remote(self, name: str) -> None: if len(added) + len(removed) + len(missing) == 0: if len(source_valid) == len(remote_valid): - self.matcher.log_messages([name, "Playlist unchanged and no missing URIs, skipping match"]) + self.matcher.log([name, "Playlist unchanged and no missing URIs, skipping match"]) return # if item collection originally contained duplicate URIS and one or more of the duplicates were removed, @@ -379,10 +379,10 @@ def _match_to_remote(self, name: str) -> None: if remote_counts.get(uri) != count: missing.extend([item for item in source_valid if item.uri == uri]) - self.matcher.log_messages([name, 
f"{len(added):>6} items added"]) - self.matcher.log_messages([name, f"{len(removed):>6} items removed"]) - self.matcher.log_messages([name, f"{len(missing):>6} items in source missing URI"]) - self.matcher.log_messages([name, f"{len(source_valid) - len(remote_valid):>6} total difference"]) + self.matcher.log([name, f"{len(added):>6} items added"]) + self.matcher.log([name, f"{len(removed):>6} items removed"]) + self.matcher.log([name, f"{len(missing):>6} items in source missing URI"]) + self.matcher.log([name, f"{len(source_valid) - len(remote_valid):>6} total difference"]) remaining = removed + missing count_start = len(remaining) @@ -406,8 +406,8 @@ def _match_to_remote(self, name: str) -> None: self._remaining = removed + missing count_final = len(self._remaining) - self.matcher.log_messages([name, f"{count_start - count_final:>6} items switched"]) - self.matcher.log_messages([name, f"{count_final:>6} items still not found"]) + self.matcher.log([name, f"{count_start - count_final:>6} items switched"]) + self.matcher.log([name, f"{count_final:>6} items still not found"]) def _match_to_input(self, name: str) -> None: """ @@ -419,6 +419,7 @@ def _match_to_input(self, name: str) -> None: header = [f"\t\33[1;94m{name}:\33[91m The following items were removed and/or matches were not found. \33[0m"] options = { + f"<{self.api.source} ID/URL/URI>": "Assign the given ID/URL/URI to the item", "u": f"Mark item as 'Unavailable on {self.api.source}'", "n": f"Leave item with no URI. 
({PROGRAM_NAME} will still attempt to find this item at the next run)", "a": "Add in addition to 'u' or 'n' options to apply this setting to all items in this playlist", @@ -433,13 +434,13 @@ def _match_to_input(self, name: str) -> None: help_text = self._format_help_text(options=options, header=header) help_text += "OR enter a custom URI/URL/ID for this item\n" - self.matcher.log_messages([name, f"Getting user input for {len(self._remaining)} items"]) + self.matcher.log([name, f"Getting user input for {len(self._remaining)} items"]) max_width = get_max_width({item.name for item in self._remaining}) print("\n" + help_text) for item in self._remaining.copy(): while current_input is not None and item in self._remaining: # while item not matched or skipped - self.matcher.log_messages([name, f"{len(self._remaining):>6} remaining items"]) + self.matcher.log([name, f"{len(self._remaining):>6} remaining items"]) if 'a' not in current_input: current_input = self._get_user_input(align_string(item.name, max_width=max_width)) @@ -453,24 +454,24 @@ def _match_to_input(self, name: str) -> None: def _match_item_to_input(self, name: str, item: MusifyItemSettable, current_input: str) -> str | None: if current_input.casefold() == 's' or current_input.casefold() == 'q': # quit/skip - self.matcher.log_messages([name, "Skipping all loops"], pad="<") + self.matcher.log([name, "Skipping all loops"], pad="<") self._quit = current_input.casefold() == 'q' or self._quit self._skip = current_input.casefold() == 's' or self._skip self._remaining.clear() return elif current_input.casefold().replace('a', '') == 'u': # mark item as unavailable - self.matcher.log_messages([name, "Marking as unavailable"], pad="<") + self.matcher.log([name, "Marking as unavailable"], pad="<") item.uri = self.api.wrangler.unavailable_uri_dummy self._remaining.remove(item) elif current_input.casefold().replace('a', '') == 'n': # leave item without URI and unprocessed - self.matcher.log_messages([name, "Skipping"], 
pad="<") + self.matcher.log([name, "Skipping"], pad="<") item.uri = None self._remaining.remove(item) elif current_input.casefold() == 'r': # return to former 'while' loop - self.matcher.log_messages([name, "Refreshing playlist metadata and restarting loop"]) + self.matcher.log([name, "Refreshing playlist metadata and restarting loop"]) return elif current_input.casefold() == 'p' and hasattr(item, "path"): # print item path @@ -481,7 +482,7 @@ def _match_item_to_input(self, name: str, item: MusifyItemSettable, current_inpu current_input, kind=RemoteObjectType.TRACK, type_out=RemoteIDType.URI ) - self.matcher.log_messages([name, f"Updating URI: {item.uri} -> {uri}"], pad="<") + self.matcher.log([name, f"Updating URI: {item.uri} -> {uri}"], pad="<") item.uri = uri self._switched.append(item) diff --git a/musify/libraries/remote/core/processors/search.py b/musify/libraries/remote/core/processors/search.py index 0896ea80..be2f12df 100644 --- a/musify/libraries/remote/core/processors/search.py +++ b/musify/libraries/remote/core/processors/search.py @@ -6,9 +6,8 @@ """ import logging from collections.abc import Mapping, Sequence, Iterable, Collection -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any +from typing import Any, Self from musify.core.base import MusifyObject, MusifyItemSettable from musify.core.enum import TagField, TagFields as Tag @@ -114,28 +113,35 @@ def __init__(self, matcher: ItemMatcher, object_factory: RemoteObjectFactory): #: The :py:class:`RemoteObjectFactory` to use when creating new remote objects. 
self.factory = object_factory - def _get_results( + async def __aenter__(self) -> Self: + await self.api.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.api.__aexit__(exc_type, exc_val, exc_tb) + + async def _get_results( self, item: MusifyObject, kind: RemoteObjectType, settings: SearchConfig ) -> list[dict[str, Any]] | None: """Query the API to get results for the current item based on algorithm settings""" self.matcher.clean_tags(item) - def execute_query(keys: Iterable[TagField]) -> tuple[list[dict[str, Any]], str]: + async def execute_query(keys: Iterable[TagField]) -> tuple[list[dict[str, Any]], str]: """Generate and execute the query against the API for the given item's cleaned ``keys``""" attributes = [item.clean_tags.get(key) for key in keys] q = " ".join(str(attr) for attr in attributes if attr) - return self.api.query(q, kind=kind, limit=settings.result_count), q + return await self.api.query(q, kind=kind, limit=settings.result_count), q - results, query = execute_query(settings.search_fields_1) + results, query = await execute_query(settings.search_fields_1) if not results and settings.search_fields_2: - results, query = execute_query(settings.search_fields_2) + results, query = await execute_query(settings.search_fields_2) if not results and settings.search_fields_3: - results, query = execute_query(settings.search_fields_3) + results, query = await execute_query(settings.search_fields_3) if results: - self.matcher.log_messages([item.name, f"Query: {query}", f"{len(results)} results"]) + self.matcher.log([item.name, f"Query: {query}", f"{len(results)} results"]) return results - self.matcher.log_messages([item.name, f"Query: {query}", "Match failed: No results."], pad="<") + self.matcher.log([item.name, f"Query: {query}", "Match failed: No results."], pad="<") def _log_results(self, results: Mapping[str, ItemSearchResult]) -> None: """Logs the final results of the ItemSearcher""" @@ -187,10 
+193,7 @@ def _determine_remote_object_type(obj: MusifyObject) -> RemoteObjectType: return obj.kind raise MusifyAttributeError(f"Given object does not specify a RemoteObjectType: {obj.__class__.__name__}") - def __call__(self, *args, **kwargs) -> dict[str, ItemSearchResult]: - return self.search(*args, **kwargs) - - def search[T: MusifyItemSettable]( + async def search[T: MusifyItemSettable]( self, collections: Collection[MusifyCollection[T]] ) -> dict[str, ItemSearchResult[T]]: """ @@ -211,34 +214,33 @@ def search[T: MusifyItemSettable]( ) bar = self.logger.get_iterator(iterable=collections, desc="Searching", unit=f"{kind}s") - with ThreadPoolExecutor(thread_name_prefix="searcher-main") as executor: - search_results = dict(executor.map(lambda coll: (coll.name, self._search_collection(coll)), bar)) + search_results = {coll.name: await self._search_collection(coll) for coll in bar} self.logger.print() self._log_results(search_results) self.logger.debug("Searching: DONE\n") return search_results - def _search_collection[T: MusifyItemSettable](self, collection: MusifyCollection) -> ItemSearchResult[T]: + async def _search_collection[T: MusifyItemSettable](self, collection: MusifyCollection) -> ItemSearchResult[T]: kind = collection.__class__.__name__ skipped = tuple(item for item in collection if item.has_uri is not None) if len(skipped) == len(collection): - self.matcher.log_messages([collection.name, "Skipping search, no items to search"], pad='<') + self.matcher.log([collection.name, "Skipping search, no items to search"], pad='<') if getattr(collection, "compilation", True) is False: - self.matcher.log_messages([collection.name, "Searching for collection as a unit"], pad='>') - self._search_collection_unit(collection=collection) + self.matcher.log([collection.name, "Searching for collection as a unit"], pad='>') + await self._search_collection_unit(collection=collection) missing = [item for item in collection.items if item.has_uri is None] if missing: - 
self.matcher.log_messages( + self.matcher.log( [collection.name, f"Searching for {len(missing)} unmatched items in this {kind}"] ) - self._search_items(collection=collection) + await self._search_items(collection=collection) else: - self.matcher.log_messages([collection.name, "Searching for distinct items in collection"], pad='>') - self._search_items(collection=collection) + self.matcher.log([collection.name, "Searching for distinct items in collection"], pad='>') + await self._search_items(collection=collection) return ItemSearchResult( matched=tuple(item for item in collection if item.has_uri and item not in skipped), @@ -246,14 +248,14 @@ def _search_collection[T: MusifyItemSettable](self, collection: MusifyCollection skipped=skipped ) - def _get_item_match[T: MusifyItemSettable]( + async def _get_item_match[T: MusifyItemSettable]( self, item: T, match_on: UnitIterable[TagField] | None = None, results: Iterable[T] = None ) -> tuple[T, T | None]: kind = self._determine_remote_object_type(item) search_config = self.search_settings[kind] if results is None: - responses = self._get_results(item, kind=kind, settings=search_config) + responses = await self._get_results(item, kind=kind, settings=search_config) # noinspection PyTypeChecker results: Iterable[T] = map(self.factory[kind], responses or ()) @@ -268,16 +270,17 @@ def _get_item_match[T: MusifyItemSettable]( return item, result - def _search_items[T: MusifyItemSettable](self, collection: Iterable[T]) -> None: + async def _search_items[T: MusifyItemSettable](self, collection: Iterable[T]) -> None: """Search for matches on individual items in an item collection that have ``None`` on ``has_uri`` attribute""" - with ThreadPoolExecutor(thread_name_prefix="searcher-items") as executor: - matches = executor.map(self._get_item_match, filter(lambda i: i.has_uri is None, collection)) + for item in collection: + if item.has_uri is not None: + continue - for item, match in matches: + item, match = await 
self._get_item_match(item) if match and match.has_uri: item.uri = match.uri - def _search_collection_unit[T: MusifyItemSettable](self, collection: MusifyCollection[T]) -> None: + async def _search_collection_unit[T: MusifyItemSettable](self, collection: MusifyCollection[T]) -> None: """ Search for matches on an entire collection as a whole i.e. search for just the collection and not its distinct items. @@ -288,10 +291,10 @@ def _search_collection_unit[T: MusifyItemSettable](self, collection: MusifyColle kind = self._determine_remote_object_type(collection) search_config = self.search_settings[kind] - responses = self._get_results(collection, kind=kind, settings=search_config) + responses = await self._get_results(collection, kind=kind, settings=search_config) key = self.api.collection_item_map[kind] for response in responses: - self.api.extend_items(response, kind=kind, key=key) + await self.api.extend_items(response, kind=kind, key=key) # noinspection PyProtectedMember,PyTypeChecker # order to prioritise results that are closer to the item count of the input collection @@ -309,13 +312,13 @@ def _search_collection_unit[T: MusifyItemSettable](self, collection: MusifyColle if not result: return - with ThreadPoolExecutor(thread_name_prefix="searcher-collection") as executor: - matches = executor.map( - lambda item: self._get_item_match(item, match_on=[Tag.TITLE], results=result.items), - filter(lambda i: i.has_uri is None, collection) - ) + # check all items in the collection have been matched + # get matches on those that are still missing matches + for item in collection: + if item.has_uri is not None: + continue - for item, match in matches: + item, match = await self._get_item_match(item, match_on=[Tag.TITLE], results=result.items) if match and match.has_uri: item.uri = match.uri diff --git a/musify/libraries/remote/core/types.py b/musify/libraries/remote/core/types.py index 27db988d..be317f4a 100644 --- a/musify/libraries/remote/core/types.py +++ 
b/musify/libraries/remote/core/types.py @@ -2,13 +2,12 @@ All type hints to use throughout the module. """ from collections.abc import MutableMapping -from typing import Any, TypeVar +from typing import Any from musify.libraries.remote.core import RemoteResponse from musify.types import UnitMutableSequence, UnitSequence -UT = TypeVar('UT') -APIInputValue = ( +type APIInputValue = ( UnitMutableSequence[str] | UnitMutableSequence[MutableMapping[str, Any]] | UnitSequence[RemoteResponse] diff --git a/musify/libraries/remote/spotify/api/api.py b/musify/libraries/remote/spotify/api/api.py index e9b11fdd..a56a0f90 100644 --- a/musify/libraries/remote/spotify/api/api.py +++ b/musify/libraries/remote/spotify/api/api.py @@ -6,11 +6,13 @@ import base64 from collections.abc import Iterable from copy import deepcopy -from urllib.parse import urlparse + +from yarl import URL from musify import PROGRAM_NAME from musify.api.authorise import APIAuthoriser from musify.api.cache.backend.base import ResponseCache, ResponseRepository +from musify.api.cache.session import CachedSession from musify.api.exception import APIError from musify.libraries.remote.spotify.api.cache import SpotifyRequestSettings, SpotifyPaginatedRequestSettings from musify.libraries.remote.spotify.api.item import SpotifyAPIItems @@ -27,8 +29,6 @@ "url": f"{URL_AUTH}/api/token", "data": { "grant_type": "authorization_code", - "code": None, - "redirect_uri": None, }, "headers": { "content-type": "application/x-www-form-urlencoded", @@ -40,7 +40,6 @@ "params": { "client_id": "{client_id}", "response_type": "code", - "redirect_uri": None, "state": PROGRAM_NAME, "scope": "{scopes}", "show_dialog": False, @@ -50,7 +49,6 @@ "url": f"{URL_AUTH}/api/token", "data": { "grant_type": "refresh_token", - "refresh_token": None, }, "headers": { "content-type": "application/x-www-form-urlencoded", @@ -82,20 +80,20 @@ class SpotifyAPI(SpotifyAPIMisc, SpotifyAPIItems, SpotifyAPIPlaylists): def user_id(self) -> str | None: """ID of 
the currently authenticated user""" if not self.user_data: - try: - self.user_data = self.get_self() - except APIError: - return None + raise APIError( + "User data not set. Either set explicitly or enter the " + f"{self.__class__.__name__} context to set automatically." + ) return self.user_data["id"] @property def user_name(self) -> str | None: """Name of the currently authenticated user""" if not self.user_data: - try: - self.user_data = self.get_self() - except APIError: - return None + raise APIError( + "User data not set. Either set explicitly or enter the " + f"{self.__class__.__name__} context to set automatically." + ) return self.user_data["display_name"] def __init__( @@ -114,7 +112,7 @@ def __init__( "scopes": " ".join(scopes), "url": wrangler.url_api } - auth_kwargs = merge_maps(deepcopy(SPOTIFY_API_AUTH_ARGS), auth_kwargs) + auth_kwargs = merge_maps(deepcopy(SPOTIFY_API_AUTH_ARGS), auth_kwargs, extend=False, overwrite=True) safe_format_map(auth_kwargs, format_map=format_map) auth_kwargs.pop("name", None) @@ -122,10 +120,11 @@ def __init__( super().__init__(authoriser=authoriser, wrangler=wrangler, cache=cache) - def _setup_cache(self, cache: ResponseCache) -> None: - if cache is None: + async def _setup_cache(self) -> None: + if not isinstance(self.handler.session, CachedSession): return + cache = self.handler.session.cache cache.repository_getter = self._get_cache_repository cache.create_repository(SpotifyRequestSettings(name="tracks")) @@ -146,9 +145,11 @@ def _setup_cache(self, cache: ResponseCache) -> None: cache.create_repository(SpotifyRequestSettings(name="chapters")) cache.create_repository(SpotifyPaginatedRequestSettings(name="audiobook_chapters")) + await cache + @staticmethod - def _get_cache_repository(cache: ResponseCache, url: str) -> ResponseRepository | None: - path = urlparse(url).path + def _get_cache_repository(cache: ResponseCache, url: str | URL) -> ResponseRepository | None: + path = URL(url).path path_split = [part.replace("-", 
"_") for part in path.split("/")[2:]] if len(path_split) < 3: diff --git a/musify/libraries/remote/spotify/api/base.py b/musify/libraries/remote/spotify/api/base.py index 31d79da7..f8301b0f 100644 --- a/musify/libraries/remote/spotify/api/base.py +++ b/musify/libraries/remote/spotify/api/base.py @@ -3,7 +3,8 @@ """ from abc import ABC from typing import Any -from urllib.parse import parse_qs, urlparse, urlencode, quote, urlunparse + +from yarl import URL from musify.libraries.remote.core.api import RemoteAPI from musify.libraries.remote.core.enum import RemoteObjectType @@ -26,13 +27,13 @@ def _get_key(key: str | RemoteObjectType | None) -> str | None: return key.lower().rstrip("s") + "s" @staticmethod - def format_next_url(url: str, offset: int = 0, limit: int = 20) -> str: + def format_next_url(url: str | URL, offset: int = 0, limit: int = 20) -> str: """Format a `next` style URL for looping through API pages""" - url_parsed = urlparse(url) - params: dict[str, Any] = parse_qs(url_parsed.query) + url = URL(url) + + params: dict[str, Any] = dict(url.query) params["offset"] = offset params["limit"] = limit - url_parts = list(url_parsed[:]) - url_parts[4] = urlencode(params, doseq=True, quote_via=quote) - return str(urlunparse(url_parts)) + url = url.with_query(params) + return str(url) diff --git a/musify/libraries/remote/spotify/api/cache.py b/musify/libraries/remote/spotify/api/cache.py index 928486ab..04f932a2 100644 --- a/musify/libraries/remote/spotify/api/cache.py +++ b/musify/libraries/remote/spotify/api/cache.py @@ -1,6 +1,8 @@ -from urllib.parse import urlparse, parse_qs +from typing import Any -from musify.api.cache.backend.base import RequestSettings, PaginatedRequestSettings +from yarl import URL + +from musify.api.cache.backend.base import RequestSettings from musify.libraries.remote.core.enum import RemoteIDType from musify.libraries.remote.core.exception import RemoteObjectTypeError from musify.libraries.remote.spotify.processors import 
SpotifyDataWrangler @@ -8,29 +10,42 @@ class SpotifyRequestSettings(RequestSettings): - @staticmethod - def get_name(value: dict) -> str | None: - if isinstance(value, dict): - if value.get("type") == "user": - return value["display_name"] - return value.get("name") + @property + def fields(self) -> tuple[str, ...]: + return "id", - @staticmethod - def get_id(url: str) -> str | None: + def get_key(self, url: str | URL, *_, **__) -> tuple[str | None, ...]: try: - return SpotifyDataWrangler.convert(url, type_in=RemoteIDType.URL, type_out=RemoteIDType.ID) + return SpotifyDataWrangler.convert(str(url), type_in=RemoteIDType.URL, type_out=RemoteIDType.ID), except RemoteObjectTypeError: pass + return (None,) + + @staticmethod + def get_name(response: dict[str, Any]) -> str | None: + if response.get("type") == "user": + return response["display_name"] + return response.get("name") + + +class SpotifyPaginatedRequestSettings(SpotifyRequestSettings): + @property + def fields(self) -> tuple[str, ...]: + return *super().fields, "offset", "size" -class SpotifyPaginatedRequestSettings(PaginatedRequestSettings, SpotifyRequestSettings): + def get_key(self, url: str | URL, *_, **__) -> tuple[str | int | None, ...]: + base = super().get_key(url=url) + return *base, self.get_offset(url), self.get_limit(url) @staticmethod - def get_offset(url: str) -> int: - params = parse_qs(urlparse(url).query) - return int(params.get("offset", [0])[0]) + def get_offset(url: str | URL) -> int: + """Extracts the offset for a paginated request from the given ``url``.""" + params = URL(url).query + return int(params.get("offset", 0)) @staticmethod - def get_limit(url: str) -> int: - params = parse_qs(urlparse(url).query) - return int(params.get("limit", [50])[0]) + def get_limit(url: str | URL) -> int: + """Extracts the limit for a paginated request from the given ``url``.""" + params = URL(url).query + return int(params.get("limit", 50)) diff --git a/musify/libraries/remote/spotify/api/item.py 
b/musify/libraries/remote/spotify/api/item.py index ea293c89..abcf1621 100644 --- a/musify/libraries/remote/spotify/api/item.py +++ b/musify/libraries/remote/spotify/api/item.py @@ -6,8 +6,10 @@ from collections.abc import Collection, Mapping, MutableMapping from itertools import batched from typing import Any -from urllib.parse import parse_qs, urlparse +from yarl import URL + +from musify.api.cache.session import CachedSession from musify.api.exception import APIError, CacheError from musify.libraries.remote.core import RemoteResponse from musify.libraries.remote.core.enum import RemoteIDType, RemoteObjectType @@ -39,9 +41,9 @@ def _get_unit(self, key: str | None = None, kind: str | None = None) -> str: ########################################################################### ## GET helpers: Generic methods for getting items ########################################################################### - def _cache_results(self, method: str, results: list[dict[str, Any]]) -> None: + async def _cache_results(self, method: str, results: list[dict[str, Any]]) -> None: """Persist ``results`` of a given ``method`` to the cache.""" - if self.handler.cache is None: + if not isinstance(self.handler.session, CachedSession): return # take all parts of href path, excluding ID @@ -52,9 +54,9 @@ def _cache_results(self, method: str, results: list[dict[str, Any]]) -> None: ) results_mapped = {(method.upper(), result[self.id_key]): result for result in results} - repository = self.handler.cache.get_repository_from_url(next(iter(possible_urls))) + repository = self.handler.session.cache.get_repository_from_url(next(iter(possible_urls))) if repository is not None: - repository.update(results_mapped) + await repository.save_responses(results_mapped) def _sort_results( self, results: list[dict[str, Any]], results_cache: list[dict[str, Any]], id_list: Collection[str] @@ -67,7 +69,7 @@ def _sort_results( id_list = to_collection(id_list) results.sort(key=lambda result: 
id_list.index(result[self.id_key])) - def _get_items_from_cache( + async def _get_items_from_cache( self, method: str, url: str, id_list: Collection[str] ) -> tuple[list[dict[str, Any]], Collection[str], Collection[str]]: """ @@ -77,27 +79,27 @@ def _get_items_from_cache( :param id_list: List of IDs to append to the given URL. :return: (Results from the cache, IDs found in the cache, IDs not found in the cache) """ - if self.handler.cache is None: - self.logger.debug(f"{'CACHE':<7}: {url:<43} | No cache configured, skipping...") + if not isinstance(self.handler.session, CachedSession): + self.handler.log("CACHE", url, message="Cache not configured, skipping...") return [], [], id_list - repository = self.handler.cache.get_repository_from_url(url=url) + repository = self.handler.session.cache.get_repository_from_url(url=url) if repository is None: - self.logger.debug(f"{'CACHE':<7}: {url:<43} | No repository for this endpoint, skipping...") + self.handler.log("CACHE", url, message="No repository for this endpoint, skipping...") return [], [], id_list - results = repository.get_responses([(method.upper(), id_,) for id_ in id_list]) + results = await repository.get_responses([(method.upper(), id_,) for id_ in id_list]) ids_found = {result[self.id_key] for result in results} ids_not_found = {id_ for id_ in id_list if id_ not in ids_found} - self.logger.debug( - f"{'CACHE':<7}: {url:<43} | " - f"Retrieved {len(results):>6} cached responses | " - f"{len(ids_not_found):>6} not found in cache" + self.handler.log( + method="CACHE", + url=url, + message=[f"Retrieved {len(results):>6} cached responses", f"{len(ids_not_found):>6} not found in cache"] ) return results, ids_found, ids_not_found - def _get_items_multi( + async def _get_items_multi( self, url: str, id_list: Collection[str], @@ -124,7 +126,9 @@ def _get_items_multi( url = url.rstrip("/") kind = self._get_unit(key=key, kind=kind) - results_cache, ids_cached, ids_not_cached = self._get_items_from_cache(method=method, 
url=url, id_list=id_list) + results_cache, ids_cached, ids_not_cached = await self._get_items_from_cache( + method=method, url=url, id_list=id_list + ) bar = self.logger.get_iterator( iterable=ids_not_cached, @@ -134,10 +138,10 @@ def _get_items_multi( ) results: list[dict[str, Any]] = [] - log = [f"{kind.title()}:{len(ids_not_cached):>5}"] + log = f"{kind.title()}: {len(ids_not_cached):>5}" for id_ in bar: - response = self.handler.request( - method=method, url=f"{url}/{id_}", params=params, persist=False, log_pad=43, log_extra=log + response = await self.handler.request( + method=method, url=f"{url}/{id_}", params=params, persist=False, log_message=log ) if self.id_key not in response: response[self.id_key] = id_ @@ -146,12 +150,12 @@ def _get_items_multi( results.extend(response[key]) if key else results.append(response) - self._cache_results(method=method, results=results) + await self._cache_results(method=method, results=results) self._sort_results(results=results, results_cache=results_cache, id_list=id_list) return results - def _get_items_batched( + async def _get_items_batched( self, url: str, id_list: Collection[str], @@ -182,7 +186,9 @@ def _get_items_batched( url = url.rstrip("/") kind = self._get_unit(key=key, kind=kind) - results_cache, ids_cached, ids_not_cached = self._get_items_from_cache(method=method, url=url, id_list=id_list) + results_cache, ids_cached, ids_not_cached = await self._get_items_from_cache( + method=method, url=url, id_list=id_list + ) id_chunks = list(batched(ids_not_cached, limit_value(limit, floor=1, ceil=50))) bar = self.logger.get_iterator( @@ -197,17 +203,17 @@ def _get_items_batched( for idx in bar: # get responses in batches id_chunk = id_chunks[idx] params_chunk = params | {"ids": ",".join(id_chunk)} - log = [f"{kind.title() + ':':<11} {len(results) + len(id_chunk):>6}/{len(ids_not_cached):<6}"] + log = f"{kind.title() + ':':<11} {len(results) + len(id_chunk):>6}/{len(ids_not_cached):<6}" - response = 
self.handler.request( - method=method, url=url, params=params_chunk, persist=False, log_pad=43, log_extra=log + response = await self.handler.request( + method=method, url=url, params=params_chunk, persist=False, log_message=log ) if key and key not in response: raise APIError(f"Given key '{key}' not found in response keys: {list(response.keys())}") results.extend(response[key]) if key else results.append(response) - self._cache_results(method=method, results=results) + await self._cache_results(method=method, results=results) self._sort_results(results=results, results_cache=results_cache, id_list=id_list) return results @@ -215,7 +221,17 @@ def _get_items_batched( ########################################################################### ## GET endpoints ########################################################################### - def extend_items( + @staticmethod + def _reformat_user_items_block(response: MutableMapping[str, Any]) -> None: + """this usually happens on the items block of a current user's playlist""" + if "next" not in response: + response["next"] = response["href"] + if "previous" not in response: + response["previous"] = None + if "limit" not in response: + response["limit"] = int(URL(response["next"]).query.get("limit", 50)) + + async def extend_items( self, response: MutableMapping[str, Any] | RemoteResponse, kind: RemoteObjectType | str | None = None, @@ -258,16 +274,10 @@ def extend_items( if len(response[self.items_key]) == response["total"]: # skip on fully extended response url = response["href"].split("?")[0] - self.logger.debug(f"{'SKIP':<7}: {url:<43} | Response already extended") + self.handler.log("SKIP", url, message="Response already extended") return response[self.items_key] - # this usually happens on the items block of a current user's playlist - if "next" not in response: - response["next"] = response["href"] - if "previous" not in response: - response["previous"] = None - if "limit" not in response: - response["limit"] = 
int(parse_qs(urlparse(response["next"]).query).get("limit", [50])[0]) + self._reformat_user_items_block(response) kind = self._get_key(kind) or self.items_key pages = (response["total"] - len(response[self.items_key])) / (response["limit"] or 1) @@ -282,9 +292,9 @@ def extend_items( while response.get("next"): # loop through each page log_count = min(len(response[self.items_key]) + response["limit"], response["total"]) - log = [f"{log_count:>6}/{response["total"]:<6} {key or self.items_key}"] + log = f"{log_count:>6}/{response["total"]:<6} {key or self.items_key}" - response_next = self.handler.request(method=method, url=response["next"], log_pad=95, log_extra=log) + response_next = await self.handler.request(method=method, url=response["next"], log_message=log) response_next = response_next.get(key, response_next) response[self.items_key].extend(response_next[self.items_key]) @@ -292,7 +302,7 @@ def extend_items( response["next"] = response_next.get("next") response["previous"] = response_next.get("previous") - if tqdm is not None: + if tqdm is not None: # TODO: drop me bar.update(len(response_next[self.items_key])) # cache child items @@ -301,14 +311,14 @@ def extend_items( result[item_key] if item_key and item_key in result else result for result in response[self.items_key] ] if all("href" in result for result in results_to_cache): - self._cache_results(method=method, results=results_to_cache) + await self._cache_results(method=method, results=results_to_cache) - if tqdm is not None: + if tqdm is not None: # TODO: drop me bar.close() return response[self.items_key] - def get_items( + async def get_items( self, values: APIInputValue, kind: RemoteObjectType | None = None, @@ -349,7 +359,7 @@ def get_items( # input validation if not isinstance(values, RemoteResponse) and not values: # skip on empty url = f"{self.url}/{self._get_key(kind)}" if kind else self.url - self.logger.debug(f"{'SKIP':<7}: {url:<43} | No data given") + self.handler.log("SKIP", url, 
message="No data given") return [] if kind is None: # determine the item type kind = self.wrangler.get_item_type(values) @@ -361,18 +371,18 @@ def get_items( id_list = self.wrangler.extract_ids(values, kind=kind) if kind in {RemoteObjectType.USER, RemoteObjectType.PLAYLIST} or len(id_list) <= 1: - results = self._get_items_multi(url=url, id_list=id_list, kind=unit) + results = await self._get_items_multi(url=url, id_list=id_list, kind=unit) else: if kind == RemoteObjectType.ALBUM: limit = limit_value(limit, floor=1, ceil=20) - results = self._get_items_batched(url=url, id_list=id_list, key=unit, limit=limit) + results = await self._get_items_batched(url=url, id_list=id_list, key=unit, limit=limit) key = self.collection_item_map.get(kind, kind) key_name = self._get_key(key) if len(results) == 0 or any(key_name not in result for result in results) or not extend: self._merge_results_to_input(original=values, responses=results, ordered=True) self._refresh_responses(responses=values, skip_checks=False) - self.logger.debug(f"{'DONE':<7}: {url:<43} | Retrieved {len(results):>6} {unit}") + self.handler.log("DONE", url, message=f"Retrieved {len(results):>6} {unit}") return results bar = self.logger.get_iterator( @@ -381,19 +391,18 @@ def get_items( for result in bar: if result[key_name].get("next") or ("next" not in result[key_name] and result[key_name].get("href")): - self.extend_items(result[key_name], kind=kind, key=key, leave_bar=False) + self.handler.log("INFO", url, message=f"Extending {key_name} on {unit}") + await self.extend_items(result[key_name], kind=kind, key=key, leave_bar=False) self._merge_results_to_input(original=values, responses=results, ordered=True) self._refresh_responses(responses=values, skip_checks=False) item_count = sum(len(result[key_name][self.items_key]) for result in results) - self.logger.debug( - f"{'DONE':<7}: {url:<71} | Retrieved {item_count:>6} {key_name} across {len(results):>5} {unit}" - ) + self.handler.log("DONE", url, 
message=f"Retrieved {item_count:>6} {key_name} across {len(results):>5} {unit}") return results - def get_user_items( + async def get_user_items( self, user: str | None = None, kind: RemoteObjectType = RemoteObjectType.PLAYLIST, @@ -435,14 +444,14 @@ def get_user_items( desc_qualifier = "current user's" if kind == RemoteObjectType.PLAYLIST else "current user's saved" desc = f"Getting {desc_qualifier} {kind.name.lower()}s" - initial = self.handler.get(url, params=params, log_pad=71) - results = self.extend_items(initial, kind=desc, key=kind) + initial = await self.handler.get(url, params=params) + results = await self.extend_items(initial, kind=desc, key=kind) - self.logger.debug(f"{'DONE':<7}: {url:<43} | Retrieved {len(results):>6} {kind.name.lower()}s") + self.handler.log("DONE", url, message=f"Retrieved {len(results):>6} {kind.name.lower()}s") return results - def extend_tracks( + async def extend_tracks( self, values: APIInputValue, features: bool = False, @@ -481,7 +490,7 @@ def extend_tracks( if not features and not analysis: # skip on all False return [] if not values: # skip on empty - self.logger.debug(f"{'SKIP':<7}: {self.url:<43} | No data given") + self.handler.log("SKIP", self.url, message="No data given") return [] self.wrangler.validate_item_type(values, kind=RemoteObjectType.TRACK) @@ -501,13 +510,13 @@ def extend_tracks( result = id_map.copy() for (url, key, _) in config.values(): - result[key] = self.handler.get(f"{url}/{id_}", log_pad=43) | id_map.copy() + result[key] = await self.handler.get(f"{url}/{id_}") | id_map.copy() results = [result] else: results = [] for kind, (url, key, batch) in config.items(): method = self._get_items_batched if batch else self._get_items_multi - responses = method(url=url, id_list=id_list, kind=kind, key=key if batch else None, limit=limit) + responses = await method(url=url, id_list=id_list, kind=kind, key=key if batch else None, limit=limit) responses.sort(key=lambda response: 
id_list.index(response[self.id_key])) responses = [{self.id_key: response[self.id_key], key: response} for response in responses] @@ -523,15 +532,16 @@ def map_key(value: str) -> str: """Map the given ``value`` to logging appropriate string""" return value.replace("_", " ").replace("analysis", "analyses") - log_url = f"{self.url}/{"+".join(c[0].split("/")[-1] for c in config.values())}" - self.logger.debug( - f"{'DONE':<7}: {log_url:<71} | " - f"Retrieved {" and ".join(map_key(key) for _, key, _ in config.values())} for {len(id_list):>5} tracks" + log_types = " and ".join(map_key(key) for _, key, _ in config.values()) + self.handler.log( + method="DONE", + url=f"{self.url}/{"+".join(c[0].split("/")[-1] for c in config.values())}", + message=f"Retrieved {log_types} for {len(id_list):>5} tracks" ) return results - def get_tracks( + async def get_tracks( self, values: APIInputValue, features: bool = False, @@ -568,7 +578,7 @@ def get_tracks( :return: API JSON responses for each item, or the original response if the input ``values`` are API responses. :raise RemoteObjectTypeError: Raised when the item types of the input ``values`` are not all tracks or IDs. 
""" - tracks = self.get_items(values=values, kind=RemoteObjectType.TRACK, limit=limit) + tracks = await self.get_items(values=values, kind=RemoteObjectType.TRACK, limit=limit) # ensure that response are being assigned back to the original values if API response(s) given if isinstance(values, Mapping | RemoteResponse): @@ -576,10 +586,10 @@ def get_tracks( elif isinstance(values, Collection) and all(isinstance(v, Mapping | RemoteResponse) for v in values): tracks = values - self.extend_tracks(values=tracks, features=features, analysis=analysis, limit=limit) + await self.extend_tracks(values=tracks, features=features, analysis=analysis, limit=limit) return tracks - def get_artist_albums( + async def get_artist_albums( self, values: APIInputValue, types: Collection[str] = (), limit: int = 50, ) -> dict[str, list[dict[str, Any]]]: """ @@ -609,7 +619,7 @@ def get_artist_albums( # input validation if not isinstance(values, RemoteResponse) and not values: # skip on empty - self.logger.debug(f"{'SKIP':<7}: {url:<43} | No data given") + self.handler.log("SKIP", url, message="No data given") return {} if types and not all(t in ARTIST_ALBUM_TYPES for t in types): @@ -630,8 +640,8 @@ def get_artist_albums( key = RemoteObjectType.ALBUM results: dict[str, dict[str, Any]] = {} for id_ in bar: - results[id_] = self.handler.get(url=url.format(id=id_), params=params) - self.extend_items(results[id_], kind="artist albums", key=key, leave_bar=False) + results[id_] = await self.handler.get(url=url.format(id=id_), params=params) + await self.extend_items(results[id_], kind="artist albums", key=key, leave_bar=False) for album in results[id_][self.items_key]: # add skeleton items block to album responses album["tracks"] = { @@ -644,9 +654,10 @@ def get_artist_albums( self._refresh_responses(responses=values, skip_checks=True) item_count = sum(len(result) for result in results.values()) - self.logger.debug( - f"{'DONE':<7}: {url.format(id="..."):<71} | " - f"Retrieved {item_count:>6} 
albums across {len(results):>5} artists" + self.handler.log( + method="DONE", + url=url.format(id="..."), + message=f"Retrieved {item_count:>6} albums across {len(results):>5} artists", ) return {k: v[self.items_key] for k, v in results.items()} diff --git a/musify/libraries/remote/spotify/api/misc.py b/musify/libraries/remote/spotify/api/misc.py index 6689439c..8da53bfa 100644 --- a/musify/libraries/remote/spotify/api/misc.py +++ b/musify/libraries/remote/spotify/api/misc.py @@ -1,6 +1,7 @@ """ Implements all required non-items and non-playlist endpoints from the Spotify API. """ +import logging from abc import ABC from collections.abc import MutableMapping from typing import Any @@ -15,7 +16,7 @@ class SpotifyAPIMisc(SpotifyAPIBase, ABC): __slots__ = () - def print_collection( + async def print_collection( self, value: str | MutableMapping[str, Any] | None = None, kind: RemoteIDType | None = None, @@ -34,8 +35,8 @@ def print_collection( url = self.wrangler.convert(id_, kind=kind, type_in=RemoteIDType.ID, type_out=RemoteIDType.URL) limit = limit_value(limit, floor=1, ceil=50) - name = self.handler.get(url, params={"limit": limit}, log_pad=43)["name"] - response = self.handler.get(f"{url}/{key}s", params={"limit": limit}, log_pad=43) + name = (await self.handler.get(url, params={"limit": limit}))["name"] + response = await self.handler.get(f"{url}/{key}s", params={"limit": limit}) i = 0 while response.get("next") or i < response["total"]: # loop through each page, printing data in blocks of 20 @@ -55,24 +56,24 @@ def print_collection( self.print_item(i=i, name=track["name"], uri=track["uri"], length=length, total=response["total"]) if response["next"]: - response = self.handler.get(response["next"], params={"limit": limit}) + response = await self.handler.get(response["next"], params={"limit": limit}) print() ########################################################################### ## GET endpoints 
########################################################################### - def get_self(self, update_user_data: bool = True) -> dict[str, Any]: + async def get_self(self, update_user_data: bool = True) -> dict[str, Any]: """ ``GET: /me`` - Get API response for information on current user. :param update_user_data: When True, update the ``_user_data`` stored in this API object. """ - r = self.handler.get(url=f"{self.url}/me", log_pad=71) + r = await self.handler.get(url=f"{self.url}/me") if update_user_data: self.user_data = r return r - def query(self, query: str | None, kind: RemoteObjectType, limit: int = 10) -> list[dict[str, Any]]: + async def query(self, query: str | None, kind: RemoteObjectType, limit: int = 10) -> list[dict[str, Any]]: """ ``GET: /search`` - Query for items. Modify result types returned with kind parameter @@ -86,10 +87,10 @@ def query(self, query: str | None, kind: RemoteObjectType, limit: int = 10) -> l url = f"{self.url}/search" params = {'q': query, "type": kind.name.lower(), "limit": limit_value(limit, floor=1, ceil=50)} - response = self.handler.get(url, params=params) + response = await self.handler.get(url, params=params) if "error" in response: - self.logger.error(f"{'ERROR':<7}: {url:<43} | Query: {query} | {response['error']}") + self.handler.log("SKIP", url, message=[f"Query: {query}", response['error']], level=logging.ERROR) return [] results = response[f"{kind.name.lower()}s"][self.items_key] diff --git a/musify/libraries/remote/spotify/api/playlist.py b/musify/libraries/remote/spotify/api/playlist.py index 0fbce66a..8cce3f2f 100644 --- a/musify/libraries/remote/spotify/api/playlist.py +++ b/musify/libraries/remote/spotify/api/playlist.py @@ -19,7 +19,7 @@ class SpotifyAPIPlaylists(SpotifyAPIBase, ABC): __slots__ = () - def get_playlist_url(self, playlist: str | Mapping[str, Any] | RemoteResponse) -> str: + async def get_playlist_url(self, playlist: str | Mapping[str, Any] | RemoteResponse) -> str: """ Determine the type of 
the given ``playlist`` and return its API URL. If type cannot be determined, attempt to find the playlist in the @@ -58,7 +58,7 @@ def get_playlist_url(self, playlist: str | Mapping[str, Any] | RemoteResponse) - try: return self.wrangler.convert(playlist, kind=RemoteObjectType.PLAYLIST, type_out=RemoteIDType.URL) except RemoteIDTypeError: - playlists = self.get_user_items(kind=RemoteObjectType.PLAYLIST) + playlists = await self.get_user_items(kind=RemoteObjectType.PLAYLIST) playlists = {pl["name"]: pl["href"] for pl in playlists} if playlist not in playlists: raise RemoteIDTypeError( @@ -70,7 +70,7 @@ def get_playlist_url(self, playlist: str | Mapping[str, Any] | RemoteResponse) - ########################################################################### ## POST endpoints ########################################################################### - def create_playlist(self, name: str, public: bool = True, collaborative: bool = False, *_, **__) -> str: + async def create_playlist(self, name: str, public: bool = True, collaborative: bool = False, *_, **__) -> str: """ ``POST: /users/{user_id}/playlists`` - Create an empty playlist for the current user with the given name. @@ -87,12 +87,12 @@ def create_playlist(self, name: str, public: bool = True, collaborative: bool = "public": public, "collaborative": collaborative, } - pl_url = self.handler.post(url, json=body, log_pad=71)["href"] + pl_url = (await self.handler.post(url, json=body))["href"] - self.logger.debug(f"{'DONE':<7}: {url:<71} | Created playlist: '{name}' -> {pl_url}") + self.handler.log("DONE", url, message=f"Created playlist: '{name}' -> {pl_url}") return pl_url - def add_to_playlist( + async def add_to_playlist( self, playlist: str | Mapping[str, Any] | RemoteResponse, items: Collection[str], @@ -116,10 +116,10 @@ def add_to_playlist( :raise RemoteObjectTypeError: Raised when the item types of the input ``items`` are not all tracks or IDs. 
""" - url = f"{self.get_playlist_url(playlist)}/tracks" + url = f"{await self.get_playlist_url(playlist)}/tracks" if len(items) == 0: - self.logger.debug(f"{'SKIP':<7}: {url:<43} | No data given") + self.handler.log("SKIP", url, message="No data given") return 0 self.wrangler.validate_item_type(items, kind=RemoteObjectType.TRACK) @@ -128,7 +128,7 @@ def add_to_playlist( self.wrangler.convert(item, kind=RemoteObjectType.TRACK, type_out=RemoteIDType.URI) for item in items ] if skip_dupes: # skip tracks currently in playlist - pl_current = self.get_items(url, kind=RemoteObjectType.PLAYLIST)[0] + pl_current = next(iter(await self.get_items(url, kind=RemoteObjectType.PLAYLIST))) tracks_key = self.collection_item_map[RemoteObjectType.PLAYLIST].name.lower() + "s" tracks = pl_current[tracks_key][self.items_key] @@ -137,16 +137,15 @@ def add_to_playlist( limit = limit_value(limit, floor=1, ceil=100) for uris in batched(uri_list, limit): # add tracks in batches - log = [f"Adding {len(uris):>6} items"] - self.handler.post(url, json={"uris": uris}, log_pad=71, log_extra=log) + await self.handler.post(url, json={"uris": uris}, log_message=f"Adding {len(uris):>6} items") - self.logger.debug(f"{'DONE':<7}: {url:<71} | Added {len(uri_list):>6} items to playlist: {url}") + self.handler.log("DONE", url, message=f"Added {len(uri_list):>6} items to playlist: {url}") return len(uri_list) ########################################################################### ## DELETE endpoints ########################################################################### - def delete_playlist(self, playlist: str | Mapping[str, Any] | RemoteResponse) -> str: + async def delete_playlist(self, playlist: str | Mapping[str, Any] | RemoteResponse) -> str: """ ``DELETE: /playlists/{playlist_id}/followers`` - Unfollow a given playlist. WARNING: This function will destructively modify your remote playlists. 
@@ -158,11 +157,11 @@ def delete_playlist(self, playlist: str | Mapping[str, Any] | RemoteResponse) -> - a RemoteResponse object representing a remote playlist. :return: API URL for playlist. """ - url = f"{self.get_playlist_url(playlist)}/followers" - self.handler.delete(url, log_pad=43) + url = f"{await self.get_playlist_url(playlist)}/followers" + await self.handler.delete(url) return url - def clear_from_playlist( + async def clear_from_playlist( self, playlist: str | Mapping[str, Any] | RemoteResponse, items: Collection[str] | None = None, @@ -186,13 +185,13 @@ def clear_from_playlist( :raise RemoteObjectTypeError: Raised when the item types of the input ``items`` are not all tracks or IDs. """ - url = f"{self.get_playlist_url(playlist)}/tracks" + url = f"{await self.get_playlist_url(playlist)}/tracks" if items is not None and len(items) == 0: - self.logger.debug(f"{'SKIP':<7}: {url:<43} | No data given") + self.handler.log("SKIP", url, message="No data given") return 0 if items is None: # clear everything - pl_current = self.get_items(url, kind=RemoteObjectType.PLAYLIST, extend=True)[0] + pl_current = next(iter(await self.get_items(url, kind=RemoteObjectType.PLAYLIST, extend=True))) tracks_key = self.collection_item_map[RemoteObjectType.PLAYLIST].name.lower() + "s" tracks = pl_current[tracks_key][self.items_key] @@ -204,14 +203,13 @@ def clear_from_playlist( ] if not uri_list: # skip when nothing to clear - self.logger.debug(f"{'SKIP':<7}: {url:<43} | No tracks to clear") + self.handler.log("SKIP", url, message="No tracks to clear") return 0 limit = limit_value(limit, floor=1, ceil=100) for uris in batched(uri_list, limit): # clear in batches body = {"tracks": [{"uri": uri} for uri in uris]} - log = [f"Clearing {len(uri_list):>3} tracks"] - self.handler.delete(url, json=body, log_pad=71, log_extra=log) + await self.handler.delete(url, json=body, log_message=f"Clearing {len(uri_list):>3} tracks") - self.logger.debug(f"{'DONE':<7}: {url:<71} | Cleared 
{len(uri_list):>3} tracks") + self.handler.log("DONE", url, message=f"Cleared {len(uri_list):>3} tracks") return len(uri_list) diff --git a/musify/libraries/remote/spotify/base.py b/musify/libraries/remote/spotify/base.py index cbd1ca63..30ef7d3f 100644 --- a/musify/libraries/remote/spotify/base.py +++ b/musify/libraries/remote/spotify/base.py @@ -16,8 +16,6 @@ class SpotifyObject(RemoteObject[SpotifyAPI], ABC): __slots__ = () - _url_pad = 71 - @property def id(self): return self.response["id"] diff --git a/musify/libraries/remote/spotify/library.py b/musify/libraries/remote/spotify/library.py index 4312fa73..7074a19c 100644 --- a/musify/libraries/remote/spotify/library.py +++ b/musify/libraries/remote/spotify/library.py @@ -45,7 +45,7 @@ def _filter_playlists(self, responses: list[dict[str, Any]]) -> list[dict[str, A def _get_total_tracks(self, responses: list[dict[str, Any]]) -> int: return sum(pl["tracks"]["total"] for pl in responses) - def enrich_tracks( + async def enrich_tracks( self, features: bool = False, analysis: bool = False, albums: bool = False, artists: bool = False ) -> None: """ @@ -65,12 +65,12 @@ def enrich_tracks( ) tracks = [track for track in self.tracks if track.has_uri] - self.api.extend_tracks(tracks, features=features, analysis=analysis) + await self.api.extend_tracks(tracks, features=features, analysis=analysis) # enrich on list of URIs to avoid duplicate calls for same items if albums: album_uris: set[str] = {track.response["album"]["uri"] for track in self.tracks} - album_responses = self.api.get_items(album_uris, kind=RemoteObjectType.ALBUM, extend=False) + album_responses = await self.api.get_items(album_uris, kind=RemoteObjectType.ALBUM, extend=False) for album in album_responses: album.pop("tracks") @@ -80,7 +80,7 @@ def enrich_tracks( if artists: artist_uris: set[str] = {artist["uri"] for track in self.tracks for artist in track.response["artists"]} - artist_responses = self.api.get_items(artist_uris, kind=RemoteObjectType.ARTIST, 
extend=False) + artist_responses = await self.api.get_items(artist_uris, kind=RemoteObjectType.ARTIST, extend=False) artists = {response["uri"]: response for response in artist_responses} for track in self.tracks: @@ -91,7 +91,7 @@ def enrich_tracks( self.logger.debug(f"Enrich {self.api.source} tracks: DONE\n") - def enrich_saved_albums(self) -> None: + async def enrich_saved_albums(self) -> None: """Extends the tracks data for currently loaded albums, getting all available tracks data for each album""" if not self.albums or all(len(album) == album.track_total for album in self.albums): return @@ -104,7 +104,7 @@ def enrich_saved_albums(self) -> None: key = self.api.collection_item_map[kind] for album in self.albums: - self.api.extend_items(album, kind=kind, key=key) + await self.api.extend_items(album, kind=kind, key=key) album.refresh(skip_checks=False) for track in album.tracks: # add tracks from this album to the user's saved tracks @@ -113,7 +113,7 @@ def enrich_saved_albums(self) -> None: self.logger.debug(f"Enrich {self.api.source} artists: DONE\n") - def enrich_saved_artists(self, tracks: bool = False, types: Collection[str] = ()) -> None: + async def enrich_saved_artists(self, tracks: bool = False, types: Collection[str] = ()) -> None: """ Gets all albums for current loaded following artists. 
@@ -128,7 +128,7 @@ def enrich_saved_artists(self, tracks: bool = False, types: Collection[str] = () f"\33[1;95m >\33[1;97m Enriching {len(self.artists)} {self.api.source} artists \33[0m" ) - self.api.get_artist_albums(self.artists, types=types) + await self.api.get_artist_albums(self.artists, types=types) if tracks: kind = RemoteObjectType.ALBUM @@ -137,7 +137,7 @@ def enrich_saved_artists(self, tracks: bool = False, types: Collection[str] = () responses_albums = [album for artist in self.artists for album in artist.albums] bar = self.logger.get_iterator(iterable=responses_albums, desc="Getting album tracks", unit="albums") for album in bar: - self.api.extend_items(album, kind=kind, key=key) + await self.api.extend_items(album, kind=kind, key=key) album.refresh(skip_checks=False) self.logger.debug(f"Enrich {self.api.source} artists: DONE\n") diff --git a/musify/libraries/remote/spotify/object.py b/musify/libraries/remote/spotify/object.py index 01b4be6b..7a2f731d 100644 --- a/musify/libraries/remote/spotify/object.py +++ b/musify/libraries/remote/spotify/object.py @@ -166,14 +166,11 @@ def comments(self, value: UnitCollection[str] | None): @property def image_links(self): album = self.response.get("album", {}) - images = {image["height"]: image["url"] for image in album.get("images", [])} - if not images: + if not (images := album.get("images", [])): return {} - return {"cover_front": next(url for height, url in images.items() if height == max(images))} - @property - def has_image(self): - return len(self.response.get("album", {}).get("images", [])) > 0 + images = {image["height"]: image["url"] for image in images} + return {"cover_front": next(url for height, url in images.items() if height == max(images))} @property def length(self): @@ -204,7 +201,7 @@ def refresh(self, skip_checks: bool = False) -> None: ] @classmethod - def load( + async def load( cls, value: str | Mapping[str, Any] | RemoteResponse, api: SpotifyAPI, @@ -225,10 +222,13 @@ def load( id_, 
kind=RemoteObjectType.TRACK, type_in=RemoteIDType.ID, type_out=RemoteIDType.URL ) } - self.reload(features=features, analysis=analysis, extend_album=extend_album, extend_artists=extend_artists) + await self.reload( + features=features, analysis=analysis, extend_album=extend_album, extend_artists=extend_artists + ) + return self - def reload( + async def reload( self, features: bool = False, analysis: bool = False, @@ -240,13 +240,13 @@ def reload( self._check_for_api() # reload with enriched data - response = self.api.handler.get(self.url, log_pad=self._url_pad) + response = await self.api.handler.get(self.url) if extend_album: - self.api.get_items(response["album"], kind=RemoteObjectType.ALBUM, extend=False) + await self.api.get_items(response["album"], kind=RemoteObjectType.ALBUM, extend=False) if extend_artists: - self.api.get_items(response["artists"], kind=RemoteObjectType.ARTIST) + await self.api.get_items(response["artists"], kind=RemoteObjectType.ARTIST) if features or analysis: - self.api.extend_tracks(response, features=features, analysis=analysis) + await self.api.extend_tracks(response, features=features, analysis=analysis) self.__init__(response=response, api=self.api) @@ -263,7 +263,9 @@ def _get_item_kind(cls, api: SpotifyAPI) -> RemoteObjectType: @classmethod @abstractmethod - def _get_items(cls, items: Collection[str] | MutableMapping[str, Any], api: SpotifyAPI) -> list[dict[str, Any]]: + async def _get_items( + cls, items: Collection[str] | MutableMapping[str, Any], api: SpotifyAPI + ) -> list[dict[str, Any]]: """Call the ``api`` to get values for the given ``items`` URIs""" raise NotImplementedError @@ -277,7 +279,7 @@ def _filter_items(cls, items: Iterable[T], response: Mapping[str, Any]) -> Itera return items @classmethod - def _extend_response(cls, response: MutableMapping[str, Any], api: SpotifyAPI, *_, **__) -> bool: + async def _extend_response(cls, response: MutableMapping[str, Any], api: SpotifyAPI, *_, **__) -> bool: """ Apply extensions 
to specific aspects of the given ``response``. Does nothing by default. Override to implement object-specific extensions. @@ -327,7 +329,7 @@ def _merge_items_to_response( return uri_matched, uri_missing @classmethod - def _load_new(cls, value: str | Mapping[str, Any] | RemoteResponse, api: SpotifyAPI, *args, **kwargs) -> Self: + async def _load_new(cls, value: str | Mapping[str, Any] | RemoteResponse, api: SpotifyAPI, *args, **kwargs) -> Self: """ Sets up a new object of the current class for the given ``value`` by calling ``__new__`` and adding just enough attributes to the object to get :py:meth:`reload` to run. @@ -341,11 +343,11 @@ def _load_new(cls, value: str | Mapping[str, Any] | RemoteResponse, api: Spotify self.api = api self._response = {"href": url} - self.reload(*args, **kwargs) + await self.reload(*args, **kwargs) return self @classmethod - def load( + async def load( cls, value: str | Mapping[str, Any] | RemoteResponse, api: SpotifyAPI, @@ -362,22 +364,22 @@ def load( # no items given, regenerate API response from the URL if any({not items, isinstance(value, Mapping) and api.items_key not in value.get(item_key, [])}): - return cls._load_new(value=value, api=api, *args, **kwargs) + return await cls._load_new(value=value, api=api, *args, **kwargs) if isinstance(value, MutableMapping) and api.wrangler.get_item_type(value) == cls.kind: # input is response response = deepcopy(value) else: # load fresh response from the API - response = cls.api.get_items(value, kind=cls.kind)[0] + response = await cls.api.get_items(value, kind=cls.kind)[0] # filter down input items to those that match the response items = cls._filter_items(items=items, response=response) matched, missing = cls._merge_items_to_response(items=items, response=response[item_key][api.items_key]) if missing: - items_missing = cls._get_items(items=missing, api=api) + items_missing = await cls._get_items(items=missing, api=api) cls._merge_items_to_response(items=items_missing, 
response=response[item_key][api.items_key], skip=matched) - skip_checks = cls._extend_response(response=response, api=api, *args, **kwargs) + skip_checks = await cls._extend_response(response=response, api=api, *args, **kwargs) return cls(response=response, api=api, skip_checks=skip_checks) @@ -457,12 +459,11 @@ def track_total(self): @property def image_links(self): - images = {image["height"]: image["url"] for image in self.response["images"]} - return {"cover_front": url for height, url in images.items() if height == max(images)} + if not (images := self.response.get("images")): + return {} - @property - def has_image(self): - return len(self.response["images"]) > 0 + images = {image["height"]: image["url"] for image in images} + return {"cover_front": url for height, url in images.items() if height == max(images)} @property def date_created(self): @@ -494,11 +495,13 @@ def refresh(self, skip_checks: bool = False) -> None: self._check_total() @classmethod - def _get_items(cls, items: Collection[str] | MutableMapping[str, Any], api: SpotifyAPI) -> list[dict[str, Any]]: - return api.get_tracks(items) + async def _get_items( + cls, items: Collection[str] | MutableMapping[str, Any], api: SpotifyAPI + ) -> list[dict[str, Any]]: + return await api.get_tracks(items) @classmethod - def _extend_response( + async def _extend_response( cls, response: MutableMapping[str, Any], api: SpotifyAPI, @@ -512,20 +515,20 @@ def _extend_response( if extend_tracks: # noinspection PyTypeChecker - api.extend_items(response, kind=cls.kind, key=item_kind, leave_bar=leave_bar) + await api.extend_items(response, kind=cls.kind, key=item_kind, leave_bar=leave_bar) item_key = item_kind.name.lower() + "s" tracks = [item["track"] for item in response.get(item_key, {}).get(api.items_key, [])] if tracks and extend_features: - api.extend_tracks(tracks, limit=response[item_key]["limit"], features=True) + await api.extend_tracks(tracks, limit=response[item_key]["limit"], features=True) return not 
extend_tracks - def reload(self, extend_tracks: bool = False, extend_features: bool = False, *_, **__) -> None: + async def reload(self, extend_tracks: bool = False, extend_features: bool = False, *_, **__) -> None: self._check_for_api() - response = self.api.get_items(self.url, kind=RemoteObjectType.PLAYLIST, extend=False)[0] + response = next(iter(await self.api.get_items(self.url, kind=RemoteObjectType.PLAYLIST, extend=False))) - skip_checks = self._extend_response( + skip_checks = await self._extend_response( response=response, api=self.api, extend_tracks=extend_tracks, extend_features=extend_features ) @@ -608,12 +611,11 @@ def compilation(self): @property def image_links(self): - images = {image["height"]: image["url"] for image in self.response["images"]} - return {"cover_front": url for height, url in images.items() if height == max(images)} + if not (images := self.response.get("images")): + return {} - @property - def has_image(self): - return len(self.response["images"]) > 0 + images = {image["height"]: image["url"] for image in images} + return {"cover_front": url for height, url in images.items() if height == max(images)} @property def rating(self): @@ -646,8 +648,10 @@ def refresh(self, skip_checks: bool = False) -> None: track.disc_total = self.disc_total @classmethod - def _get_items(cls, items: Collection[str] | MutableMapping[str, Any], api: SpotifyAPI) -> list[dict[str, Any]]: - return api.get_tracks(items) + async def _get_items( + cls, items: Collection[str] | MutableMapping[str, Any], api: SpotifyAPI + ) -> list[dict[str, Any]]: + return await api.get_tracks(items) @classmethod def _filter_items(cls, items: Iterable[SpotifyTrack], response: Mapping[str, Any]) -> Iterable[SpotifyTrack]: @@ -655,7 +659,7 @@ def _filter_items(cls, items: Iterable[SpotifyTrack], response: Mapping[str, Any return [item for item in items if item.response.get("album", {}).get("id") == response["id"]] @classmethod - def _extend_response( + async def _extend_response( 
cls, response: MutableMapping[str, Any], api: SpotifyAPI, @@ -669,26 +673,26 @@ def _extend_response( item_kind = api.collection_item_map[cls.kind] if extend_artists: - api.get_items(response["artists"], kind=RemoteObjectType.ARTIST) + await api.get_items(response["artists"], kind=RemoteObjectType.ARTIST) if extend_tracks: # noinspection PyTypeChecker - api.extend_items(response, kind=cls.kind, key=item_kind, leave_bar=leave_bar) + await api.extend_items(response, kind=cls.kind, key=item_kind, leave_bar=leave_bar) item_key = item_kind.name.lower() + "s" tracks = response.get(item_key, {}).get(api.items_key) if tracks and extend_features: - api.extend_tracks(tracks, limit=response[item_key]["limit"], features=True) + await api.extend_tracks(tracks, limit=response[item_key]["limit"], features=True) return not extend_tracks - def reload( + async def reload( self, extend_artists: bool = False, extend_tracks: bool = False, extend_features: bool = False, *_, **__ ) -> None: self._check_for_api() - response = self.api.get_items(self.url, kind=RemoteObjectType.ALBUM, extend=False)[0] + response = next(iter(await self.api.get_items(self.url, kind=RemoteObjectType.ALBUM, extend=False))) - skip_checks = self._extend_response( + skip_checks = await self._extend_response( response=response, api=self.api, extend_tracks=extend_tracks, @@ -737,7 +741,10 @@ def genres(self): @property def image_links(self): - images = {image["height"]: image["url"] for image in self.response.get("images", [])} + if not (images := self.response.get("images")): + return {} + + images = {image["height"]: image["url"] for image in images} return {"cover_front": url for height, url in images.items() if height == max(images)} @property @@ -766,8 +773,10 @@ def _get_item_kind(cls, api: SpotifyAPI) -> RemoteObjectType: return RemoteObjectType.ALBUM @classmethod - def _get_items(cls, items: Collection[str] | MutableMapping[str, Any], api: SpotifyAPI) -> list[dict[str, Any]]: - return api.get_items(items, 
extend=False) + async def _get_items( + cls, items: Collection[str] | MutableMapping[str, Any], api: SpotifyAPI + ) -> list[dict[str, Any]]: + return await api.get_items(items, extend=False) @classmethod def _filter_items(cls, items: Iterable[SpotifyAlbum], response: dict[str, Any]) -> Iterable[SpotifyAlbum]: @@ -778,7 +787,7 @@ def _filter_items(cls, items: Iterable[SpotifyAlbum], response: dict[str, Any]) ] @classmethod - def _extend_response( + async def _extend_response( cls, response: MutableMapping[str, Any], api: SpotifyAPI, @@ -795,7 +804,7 @@ def _extend_response( response_items = response.get(item_key, {}) has_all_albums = item_key in response and len(response_items[api.items_key]) == response_items["total"] if extend_albums and not has_all_albums: - api.get_artist_albums(response, limit=response.get(item_key, {}).get("limit", 50)) + await api.get_artist_albums(response, limit=response.get(item_key, {}).get("limit", 50)) album_item_kind = api.collection_item_map[item_kind] album_item_key = album_item_kind.name.lower() + "s" @@ -803,21 +812,21 @@ def _extend_response( if albums and extend_tracks: for album in albums: - api.extend_items(album[album_item_key], kind=item_kind, key=album_item_kind) + await api.extend_items(album[album_item_key], kind=item_kind, key=album_item_kind) if albums and extend_features: tracks = [track for album in albums for track in album[album_item_key]["items"]] - api.extend_tracks(tracks, limit=response[item_key]["limit"], features=True) + await api.extend_tracks(tracks, limit=response[item_key]["limit"], features=True) return not extend_albums or not extend_tracks - def reload( + async def reload( self, extend_albums: bool = False, extend_tracks: bool = False, extend_features: bool = False, *_, **__ ) -> None: self._check_for_api() - response = self.api.handler.get(url=self.url, log_pad=self._url_pad) + response = await self.api.handler.get(url=self.url) - skip_checks = self._extend_response( + skip_checks = await 
self._extend_response( response=response, api=self.api, extend_albums=extend_albums, diff --git a/musify/libraries/remote/spotify/processors.py b/musify/libraries/remote/spotify/processors.py index be2fa3ed..96b3ff12 100644 --- a/musify/libraries/remote/spotify/processors.py +++ b/musify/libraries/remote/spotify/processors.py @@ -3,7 +3,8 @@ """ from collections.abc import Mapping from typing import Any -from urllib.parse import urlparse + +from yarl import URL from musify.exception import MusifyEnumError from musify.libraries.core.collection import MusifyCollection @@ -80,7 +81,7 @@ def _get_item_type( if value.startswith(cls.url_api) or value.startswith(cls.url_ext): # open/API URL value = value.removeprefix(cls.url_api if value.startswith(cls.url_api) else cls.url_ext) - url_path = urlparse(value).path.split("/") + url_path = URL(value).path.split("/") for chunk in url_path: try: return RemoteObjectType.from_name(chunk.casefold().rstrip('s'))[0] @@ -161,7 +162,7 @@ def _get_id( @classmethod def _get_id_from_url(cls, value: str, kind: RemoteObjectType | None = None) -> tuple[RemoteObjectType, str]: - url_path = urlparse(value).path.split("/") + url_path = URL(value).path.split("/") for chunk in url_path: try: kind = RemoteObjectType.from_name(chunk.rstrip('s'))[0] diff --git a/musify/log/logger.py b/musify/log/logger.py index e64be08d..7c51b153 100644 --- a/musify/log/logger.py +++ b/musify/log/logger.py @@ -12,7 +12,6 @@ from musify.log import INFO_EXTRA, REPORT, STAT T = TypeVar("T") - try: from tqdm.auto import tqdm ProgressBarType = tqdm | Iterable[T] @@ -76,7 +75,9 @@ def stat(self, msg, *args, **kwargs) -> None: def print(self, level: int = logging.CRITICAL + 1) -> None: """Print a new line only when DEBUG < ``logger level`` <= ``level`` for all console handlers""" - if not self.compact and all(logging.DEBUG < h.level <= level for h in self.stdout_handlers): + if not self.compact and self.stdout_handlers and all( + logging.DEBUG < h.level <= level for h in 
self.stdout_handlers + ): print() def get_iterator[T: Any]( @@ -91,7 +92,7 @@ def get_iterator[T: Any]( For tqdm kwargs, see :py:class:`tqdm_std` """ if tqdm is None: - return iterable if iterable is not None else range(total) + return iter(iterable) if iterable is not None else range(total) # noinspection SpellCheckingInspection preset_keys = ("leave", "disable", "file", "ncols", "colour", "smoothing", "position") diff --git a/musify/processors/match.py b/musify/processors/match.py index a94580a7..d30f4fb9 100644 --- a/musify/processors/match.py +++ b/musify/processors/match.py @@ -82,7 +82,7 @@ def __init__(self): #: The :py:class:`MusifyLogger` for this object self.logger: MusifyLogger = logging.getLogger(__name__) - def log_messages(self, messages: MutableSequence[str], pad: str = ' ') -> None: + def log(self, messages: MutableSequence[str], pad: str = ' ') -> None: """ Log lists of ``messages`` in a uniform aligned format with a given ``pad`` character. @@ -98,7 +98,7 @@ def _log_algorithm(self, source: MusifyObject, extra: Iterable[str] = ()) -> Non log = [source.name, algorithm] if extra: log.extend(extra) - self.log_messages(log, pad='>') + self.log(log, pad='>') def _log_test[T: MusifyObject](self, source: T, result: T | None, test: Any, extra: Iterable[str] = ()) -> None: """Wrapper for initially logging a test result in a uniform aligned format""" @@ -112,14 +112,14 @@ def _log_test[T: MusifyObject](self, source: T, result: T | None, test: Any, ext log = [source.name, log_result, f"{algorithm:<10}={test:<5}"] if extra: log.extend(extra) - self.log_messages(log) + self.log(log) def _log_match[T: MusifyObject](self, source: T, result: T, extra: Iterable[str] = ()) -> None: """Wrapper for initially logging a match in a correctly aligned format""" log = [source.name, f"< Matched URI: {result.uri}"] if extra: log.extend(extra) - self.log_messages(log, pad='<') + self.log(log, pad='<') def clean_tags(self, source: MusifyObject) -> None: """ diff --git 
a/musify/types.py b/musify/types.py index dccf91ed..8107eb35 100644 --- a/musify/types.py +++ b/musify/types.py @@ -2,14 +2,12 @@ All core type hints to use throughout the entire package. """ from collections.abc import Iterable, Sequence, MutableSequence, Collection, Mapping, MutableMapping -from typing import TypeVar -UT = TypeVar('UT') -UnitIterable = UT | Iterable[UT] -UnitCollection = UT | Collection[UT] -UnitSequence = UT | Sequence[UT] -UnitMutableSequence = UT | MutableSequence[UT] -UnitList = UT | list[UT] +type UnitIterable[T] = T | Iterable[T] +type UnitCollection[T] = T | Collection[T] +type UnitSequence[T] = T | Sequence[T] +type UnitMutableSequence[T] = T | MutableSequence[T] +type UnitList[T] = T | list[T] JSON_VALUE = str | int | float | list | dict | bool | None JSON = Mapping[str, JSON_VALUE] diff --git a/pyproject.toml b/pyproject.toml index 49155072..8f2b7368 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,14 +25,14 @@ classifiers = [ ] dependencies = [ "mutagen~=1.47", - "requests~=2.31", + "aiohttp~=3.9", "python-dateutil~=2.9", ] [project.optional-dependencies] # optional functionality all = [ - "musify[bars,images,musicbee]", + "musify[bars,images,musicbee,sqlite]", ] bars = [ "tqdm~=4.66", @@ -44,6 +44,9 @@ musicbee = [ "xmltodict~=0.13", "lxml~=5.2", ] +sqlite = [ + "aiosqlite~=0.20", +] # dev dependencies build = [ @@ -53,9 +56,10 @@ build = [ test = [ "musify[all]", "pytest~=8.2", + "pytest-asyncio~=0.23", "pytest-xdist~=3.6", "pytest-mock~=3.14", - "requests-mock~=1.12", + "aioresponses~=0.7", "pyyaml~=6.0", "pycountry~=23.12", ] @@ -104,3 +108,4 @@ markers = [ "slow: marks test as slow (deselect with '-m \"not slow\"')", "manual: marks tests to be run only when manually directed to by the developer", ] +asyncio_mode = "auto" diff --git a/tests/__resources/library/musicbee_library.xml b/tests/__resources/library/musicbee_library.xml index 9ccae650..bce5bf37 100644 --- a/tests/__resources/library/musicbee_library.xml +++ 
b/tests/__resources/library/musicbee_library.xml @@ -38,7 +38,7 @@ Play Date UTC2023-07-20T06:12:26Z Play Count5 Track TypeFile - Locationfile://localhost/{path_resources}/track/noise_mp3.mp3 + Locationfile://localhost/{path_resources}/track/noiSE_mP3.mp3 2 @@ -70,7 +70,7 @@ Play Date UTC2023-09-02T08:21:22Z Play Count10 Track TypeFile - Locationfile://localhost/{path_resources}/track/noise_flac.flac + Locationfile://localhost/{path_resources}/track/NOISE_FLaC.flac 3 diff --git a/tests/__resources/playlist/Simple Playlist.m3u b/tests/__resources/playlist/Simple Playlist.m3u index a1bdcec8..54bd6997 100644 --- a/tests/__resources/playlist/Simple Playlist.m3u +++ b/tests/__resources/playlist/Simple Playlist.m3u @@ -1,3 +1,3 @@ -../track/noise_flac.flac -../track/noise_mp3.mp3 +../track/NOISE_FLaC.flac +../track/noiSE_mP3.mp3 ../track/noise_wma.wma diff --git a/tests/__resources/playlist/The Best Playlist Ever.xautopf b/tests/__resources/playlist/The Best Playlist Ever.xautopf index 1b01e14b..e0277ca0 100644 --- a/tests/__resources/playlist/The Best Playlist Ever.xautopf +++ b/tests/__resources/playlist/The Best Playlist Ever.xautopf @@ -21,7 +21,7 @@ - ../track/noise_flac.flac|../track/noise_mp3.mp3|../track/noise_wma.wma - ../playlist/exclude_me.flac|../playlist/exclude_me_2.mp3|../track/noise_mp3.mp3 + ../track/NOISE_FLaC.flac|../track/noiSE_mP3.mp3|../track/noise_wma.wma + ../playlist/exclude_me.flac|../playlist/exclude_me_2.mp3|../track/noiSE_mP3.mp3 diff --git a/tests/__resources/track/noise_flac.flac b/tests/__resources/track/NOISE_FLaC.flac similarity index 100% rename from tests/__resources/track/noise_flac.flac rename to tests/__resources/track/NOISE_FLaC.flac diff --git a/tests/__resources/track/noise_mp3.mp3 b/tests/__resources/track/noiSE_mP3.mp3 similarity index 100% rename from tests/__resources/track/noise_mp3.mp3 rename to tests/__resources/track/noiSE_mP3.mp3 diff --git a/tests/api/cache/backend/test_sqlite.py b/tests/api/cache/backend/test_sqlite.py 
index 2cc8b4d4..0898386a 100644 --- a/tests/api/cache/backend/test_sqlite.py +++ b/tests/api/cache/backend/test_sqlite.py @@ -1,3 +1,4 @@ +import contextlib import json import sqlite3 from datetime import datetime, timedelta @@ -6,14 +7,17 @@ from random import randrange from tempfile import gettempdir from typing import Any -from urllib.parse import urlparse +import aiosqlite import pytest -from requests import Response, Request +from aiohttp import ClientRequest, ClientResponse, ClientSession +from yarl import URL -from musify.api.cache.backend.base import RequestSettings, PaginatedRequestSettings +from musify.api.cache.backend.base import RequestSettings from musify.api.cache.backend.sqlite import SQLiteTable, SQLiteCache +from musify.api.cache.response import CachedResponse from tests.api.cache.backend.testers import ResponseRepositoryTester, ResponseCacheTester, BaseResponseTester +from tests.api.cache.backend.utils import MockPaginatedRequestSettings from tests.utils import random_str @@ -22,7 +26,7 @@ class SQLiteTester(BaseResponseTester): @staticmethod def generate_connection() -> sqlite3.Connection: - return sqlite3.Connection(database=":memory:") + return aiosqlite.connect(":memory:") @staticmethod def generate_item(settings: RequestSettings) -> tuple[tuple, dict[str, Any]]: @@ -34,82 +38,123 @@ def generate_item(settings: RequestSettings) -> tuple[tuple, dict[str, Any]]: str(randrange(0, 100)): randrange(0, 100), } - if isinstance(settings, PaginatedRequestSettings): + if isinstance(settings, MockPaginatedRequestSettings): key = (*key, randrange(0, 100), randrange(1, 50)) return key, value - @staticmethod - def generate_response_from_item(settings: RequestSettings, key: tuple, value: dict[str, Any]) -> Response: + # noinspection PyProtectedMember + @classmethod + def generate_response_from_item( + cls, settings: RequestSettings, key: Any, value: Any, session: ClientSession = None + ) -> ClientResponse: url = f"http://test.com/{settings.name}/{key[1]}" + 
return cls._generate_response_from_item(url=url, key=key, value=value, session=session) + + # noinspection PyProtectedMember + @classmethod + def generate_bad_response_from_item( + cls, settings: RequestSettings, key: Any, value: Any, session: ClientSession = None + ) -> ClientResponse: + url = "http://test.com" + return cls._generate_response_from_item(url=url, key=key, value=value, session=session) + + @staticmethod + def _generate_response_from_item(url: str, key: Any, value: Any, session: ClientSession = None) -> ClientResponse: params = {} if len(key) == 4: params["offset"] = key[2] - params["size"] = key[3] - - request = Request(method=key[0], url=url, params=params).prepare() - - response = Response() - response.encoding = "utf-8" - response._content = json.dumps(value).encode(response.encoding) - response.status_code = 200 - response.url = request.url - response.request = request - - return response + params["limit"] = key[3] + + if session is not None: + # noinspection PyProtectedMember + request = ClientRequest( + method=key[0], + url=URL(url), + params=params, + headers={"Content-Type": "application/json"}, + loop=session._loop, + session=session + ) + else: + request = ClientRequest( + method=key[0], + url=URL(url), + params=params, + headers={"Content-Type": "application/json"}, + ) + return CachedResponse(request=request, data=json.dumps(value)) class TestSQLiteTable(SQLiteTester, ResponseRepositoryTester): @pytest.fixture - def repository( - self, settings: RequestSettings, connection: sqlite3.Connection, valid_items: dict, invalid_items: dict + async def repository( + self, connection: aiosqlite.Connection, settings: RequestSettings, valid_items: dict, invalid_items: dict ) -> SQLiteTable: expire = timedelta(days=2) - repository = SQLiteTable(connection=connection, settings=settings, expire=expire) - columns = ( - *repository._primary_key_columns, - repository.cached_column, - repository.expiry_column, - repository.data_column - ) - query = 
"\n".join(( - f"INSERT OR REPLACE INTO {settings.name} (", - f"\t{", ".join(columns)}", - ") ", - f"VALUES ({",".join("?" * len(columns))});", - )) - parameters = [ - (*key, datetime.now().isoformat(), repository.expire.isoformat(), repository.serialize(value)) - for key, value in valid_items.items() - ] - invalid_expire_dt = datetime.now() - expire # expiry time in the past, response cache has expired - parameters.extend( - (*key, datetime.now().isoformat(), invalid_expire_dt.isoformat(), repository.serialize(value)) - for key, value in invalid_items.items() - ) - connection.executemany(query, parameters) - - return repository - - @property - def connection_closed_exception(self) -> type[Exception]: - return sqlite3.DatabaseError - - def test_init(self, connection: sqlite3.Connection, settings: RequestSettings): - repository = SQLiteTable(connection=connection, settings=settings) - - cur = connection.execute( - f"SELECT name FROM sqlite_master WHERE type='table' AND name='{settings.name}'" - ) - rows = cur.fetchall() - assert len(rows) == 1 - assert rows[0][0] == settings.name - cur = connection.execute(f"SELECT name FROM pragma_table_info('{settings.name}');") - columns = {row[0] for row in cur} - assert {repository.name_column, repository.data_column, repository.expiry_column}.issubset(columns) - assert set(repository._primary_key_columns).issubset(columns) + async with SQLiteTable(connection, settings=settings, expire=expire) as repository: + columns = ( + *repository._primary_key_columns, + repository.cached_column, + repository.expiry_column, + repository.data_column + ) + query = "\n".join(( + f"INSERT OR REPLACE INTO {settings.name} (", + f"\t{", ".join(columns)}", + ") ", + f"VALUES ({",".join("?" 
* len(columns))});", + )) + parameters = [ + (*key, datetime.now().isoformat(), repository.expire.isoformat(), repository.serialize(value)) + for key, value in valid_items.items() + ] + invalid_expire_dt = datetime.now() - expire # expiry time in the past, response cache has expired + parameters.extend( + (*key, datetime.now().isoformat(), invalid_expire_dt.isoformat(), repository.serialize(value)) + for key, value in invalid_items.items() + ) + + await connection.executemany(query, parameters) + await connection.commit() + + yield repository + + async def test_init_fails(self, connection: aiosqlite.Connection, settings: RequestSettings): + repository = SQLiteTable(connection, settings=settings) + with pytest.raises(ValueError): + assert await repository.count() + + with pytest.raises(ValueError): + await repository + + async with connection: + async with connection.execute( + f"SELECT name FROM sqlite_master WHERE type='table' AND name='{settings.name}'" + ) as cur: + rows = await cur.fetchall() + assert len(rows) == 0 + + async def test_init(self, connection: aiosqlite.Connection, settings: RequestSettings): + async with SQLiteTable(connection, settings=settings) as repository: + async with connection.execute( + f"SELECT name FROM sqlite_master WHERE type='table' AND name='{settings.name}'" + ) as cur: + rows = await cur.fetchall() + assert len(rows) == 1 + assert rows[0][0] == settings.name + + async with connection.execute(f"SELECT name FROM pragma_table_info('{settings.name}');") as cur: + columns = {row[0] async for row in cur} + assert {repository.name_column, repository.data_column, repository.expiry_column}.issubset(columns) + assert set(repository._primary_key_columns).issubset(columns) + + assert await repository.count() == 0 + + with pytest.raises(ValueError): + assert await repository.count() def test_serialize(self, repository: SQLiteTable): _, value = self.generate_item(repository.settings) @@ -136,68 +181,79 @@ def test_deserialize(self, repository: 
SQLiteTable): class TestSQLiteCache(SQLiteTester, ResponseCacheTester): @staticmethod - def generate_response(settings: RequestSettings) -> Response: + def generate_response(settings: RequestSettings, session: ClientSession = None) -> ClientResponse: key, value = TestSQLiteTable.generate_item(settings) - return TestSQLiteTable.generate_response_from_item(settings, key, value) + return TestSQLiteTable.generate_response_from_item(settings, key, value, session=session) - @classmethod - def generate_cache(cls, connection: sqlite3.Connection) -> SQLiteCache: - cache = SQLiteCache(cache_name="test", connection=connection) - cache.repository_getter = cls.get_repository_from_url - - for _ in range(randrange(5, 10)): - settings = cls.generate_settings() - items = dict(TestSQLiteTable.generate_item(settings) for _ in range(randrange(3, 6))) - - repository = SQLiteTable(settings=settings, connection=connection) - repository.update(items) - cache[settings.name] = repository + @staticmethod + def generate_repository(settings: RequestSettings, session: ClientSession = None) -> ClientResponse: + key, value = TestSQLiteTable.generate_item(settings) + return TestSQLiteTable.generate_response_from_item(settings, key, value, session=session) - return cache + @classmethod + @contextlib.asynccontextmanager + async def generate_cache(cls) -> SQLiteCache: + async with SQLiteCache( + cache_name="test", + connector=cls.generate_connection, + repository_getter=cls.get_repository_from_url, + ) as cache: + for _ in range(randrange(5, 10)): + settings = cls.generate_settings() + items = dict(TestSQLiteTable.generate_item(settings) for _ in range(randrange(3, 6))) + + repository = await SQLiteTable(settings=settings, connection=cache.connection) + for k, v in items.items(): + await repository._set_item_from_key_value_pair(k, v) + cache[settings.name] = repository + + await cache.commit() + assert await repository.count() == len(items) + + yield cache @staticmethod - def 
get_repository_from_url(cache: SQLiteCache, url: str) -> SQLiteCache | None: + def get_repository_from_url(cache: SQLiteCache, url: str | URL) -> SQLiteTable | None: + url = URL(url) for name, repository in cache.items(): - if name == urlparse(url).path.split("/")[-2]: + if name == url.path.split("/")[-2]: return repository @staticmethod - def get_db_path(cache: SQLiteCache) -> str: + async def get_db_path(cache: SQLiteCache) -> str: """Get the DB path from the connection associated with the given ``cache``.""" - cur = cache.connection.execute("PRAGMA database_list") - rows = cur.fetchall() + async with cache.connection.execute("PRAGMA database_list") as cur: + rows = await cur.fetchall() + assert len(rows) == 1 db_seq, db_name, db_path = rows[0] return db_path - def test_connect_with_path(self, tmp_path: Path): + async def test_connect_with_path(self, tmp_path: Path): fake_name = "not my real name" path = join(tmp_path, "test") expire = timedelta(weeks=42) - cache = SQLiteCache.connect_with_path(path, cache_name=fake_name, expire=expire) - - assert self.get_db_path(cache) == path + ".sqlite" - assert cache.cache_name != fake_name - assert cache.expire == expire + async with SQLiteCache.connect_with_path(path, cache_name=fake_name, expire=expire) as cache: + assert await self.get_db_path(cache) == path + ".sqlite" + assert cache.cache_name != fake_name + assert cache.expire == expire - def test_connect_with_in_memory_db(self): + async def test_connect_with_in_memory_db(self): fake_name = "not my real name" expire = timedelta(weeks=42) - cache = SQLiteCache.connect_with_in_memory_db(cache_name=fake_name, expire=expire) + async with SQLiteCache.connect_with_in_memory_db(cache_name=fake_name, expire=expire) as cache: + assert await self.get_db_path(cache) == "" + assert cache.cache_name != fake_name + assert cache.expire == expire - assert self.get_db_path(cache) == "" - assert cache.cache_name != fake_name - assert cache.expire == expire - - def 
test_connect_with_temp_db(self): + async def test_connect_with_temp_db(self): name = "this is my real name" path = join(gettempdir(), name) expire = timedelta(weeks=42) - cache = SQLiteCache.connect_with_temp_db(name, expire=expire) - - assert self.get_db_path(cache).endswith(path + ".sqlite") - assert cache.cache_name == name - assert cache.expire == expire + async with SQLiteCache.connect_with_temp_db(name, expire=expire) as cache: + assert (await self.get_db_path(cache)).endswith(path + ".sqlite") + assert cache.cache_name == name + assert cache.expire == expire diff --git a/tests/api/cache/backend/testers.py b/tests/api/cache/backend/testers.py index 3df2f75b..7f5bb645 100644 --- a/tests/api/cache/backend/testers.py +++ b/tests/api/cache/backend/testers.py @@ -1,13 +1,13 @@ +import sqlite3 from abc import ABC, abstractmethod from random import choice, randrange from typing import Any import pytest -from requests import Response +from aiohttp import ClientResponse, ClientSession -from musify.api.cache.backend.base import ResponseRepository, Connection, ResponseCache, RequestSettings +from musify.api.cache.backend.base import ResponseRepository, ResponseCache, RequestSettings from musify.api.exception import CacheError -from musify.exception import MusifyKeyError from tests.api.cache.backend.utils import MockRequestSettings, MockPaginatedRequestSettings from tests.utils import random_str @@ -22,13 +22,13 @@ class BaseResponseTester(ABC): @staticmethod @abstractmethod - def generate_connection() -> Connection: - """Generates a :py:class:`Connection` for this backend type.""" + def generate_connection() -> Any: + """Generates and yields a :py:class:`Connection` for this backend type.""" raise NotImplementedError @pytest.fixture - def connection(self) -> Connection: - """Yields a valid :py:class:`Connection` to use throughout tests in this suite as a pytest.fixture.""" + def connection(self) -> Any: + """Yields a valid :py:class:`Connection` to use throughout tests 
in this suite as a pytest_asyncio.fixture.""" return self.generate_connection() @staticmethod @@ -40,12 +40,25 @@ def generate_item(settings: RequestSettings) -> tuple[Any, Any]: """ raise NotImplementedError - @staticmethod + @classmethod @abstractmethod - def generate_response_from_item(settings: RequestSettings, key: Any, value: Any) -> Response: + def generate_response_from_item( + cls, settings: RequestSettings, key: Any, value: Any, session: ClientSession = None + ) -> ClientResponse: """ - Generates a :py:class:`Response` appropriate for the given ``settings`` from the given ``key`` and ``value`` - that can be persisted to the repository. + Generates a :py:class:`ClientResponse` appropriate for the given ``settings`` + from the given ``key`` and ``value`` that can be persisted to the repository. + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def generate_bad_response_from_item( + cls, settings: RequestSettings, key: Any, value: Any, session: ClientSession = None + ) -> ClientResponse: + """ + Generates a bad :py:class:`ClientResponse` appropriate for the given ``settings`` + from the given ``key`` and ``value`` that can be persisted to the repository. 
""" raise NotImplementedError @@ -53,8 +66,6 @@ def generate_response_from_item(settings: RequestSettings, key: Any, value: Any) class ResponseRepositoryTester(BaseResponseTester, ABC): """Run generic tests for :py:class:`ResponseRepository` implementations.""" - bad_url = "test.com" - # noinspection PyArgumentList @pytest.fixture(scope="class", params=REQUEST_SETTINGS) def settings(self, request) -> RequestSettings: @@ -81,32 +92,37 @@ def items(self, valid_items: dict, invalid_items: dict) -> dict: return valid_items | invalid_items @abstractmethod - def repository(self, request, connection: Connection, valid_items: dict, invalid_items: dict) -> ResponseRepository: + async def repository( + self, connection: Any, settings: RequestSettings, valid_items: dict, invalid_items: dict + ) -> ResponseRepository: """ - Yields a valid :py:class:`ResponseRepository` to use throughout tests in this suite as a pytest.fixture. - Should produce a repository for each type of :py:class:`RequestSettings` type - as given by the request fixture. + Yields a valid :py:class:`ResponseRepository` to use throughout tests in this suite as a pytest_asyncio.fixture. + Populates this repository with ``valid_items`` and ``invalid_items``. 
""" raise NotImplementedError - @property - @abstractmethod - def connection_closed_exception(self) -> type[Exception]: - """Returns the exception class to expect when executing against a closed connection.""" - raise NotImplementedError + @staticmethod + async def test_close(repository: ResponseRepository): + key, _ = await anext(aiter(repository)) + await repository.close() - def test_close(self, repository: ResponseRepository): - key = next(iter(repository)) - repository.close() + with pytest.raises(ValueError): + await repository.get_response(key) - with pytest.raises(self.connection_closed_exception): - repository.get_response(key) + @staticmethod + async def test_count(repository: ResponseRepository, items: dict, valid_items: dict): + assert await repository.count() == len(items) + assert await repository.count(False) == len(valid_items) @staticmethod - def test_count(repository: ResponseRepository, items: dict, valid_items: dict): - assert len(repository) == len(valid_items) - assert repository.count() == len(items) - assert repository.count(False) == len(valid_items) + async def test_contains_and_clear(repository: ResponseRepository): + key, _ = await anext(aiter(repository)) + assert await repository.count() > 0 + assert await repository.contains(key) + + await repository.clear() + assert await repository.count() == 0 + assert not await repository.contains(key) @abstractmethod def test_serialize(self, repository: ResponseRepository): @@ -118,210 +134,175 @@ def test_deserialize(self, repository: ResponseRepository): def test_get_key_from_request(self, repository: ResponseRepository): key, value = self.generate_item(repository.settings) - request = self.generate_response_from_item(repository.settings, key, value).request + request = self.generate_response_from_item(repository.settings, key, value).request_info assert repository.get_key_from_request(request) == key def test_get_key_from_invalid_request(self, repository: ResponseRepository): key, value = 
self.generate_item(repository.settings) - request = self.generate_response_from_item(repository.settings, key, value).request - request.url = self.bad_url # should not return an ID for this URL format + request = self.generate_bad_response_from_item(repository.settings, key, value).request_info assert repository.get_key_from_request(request) is None - def test_get_response_on_missing(self, repository: ResponseRepository, valid_items: dict): - key, value = self.generate_item(repository.settings) - assert key not in repository - - with pytest.raises(MusifyKeyError): - assert repository[key] - - assert repository.get(key) is None - assert repository.get_response(key) is None - assert repository.get_responses(list(valid_items) + [key]) == list(valid_items.values()) - @staticmethod - def test_get_response_from_key(repository: ResponseRepository, valid_items: dict): + async def test_get_responses_from_keys(repository: ResponseRepository, valid_items: dict): key, value = choice(list(valid_items.items())) + assert await repository.get_response(key) == value + assert await repository.get_responses(valid_items.keys()) == list(valid_items.values()) - assert repository[key] == value - assert repository.get(key) == value - assert repository.get_response(key) == value + async def test_get_response_on_missing(self, repository: ResponseRepository, valid_items: dict): + key, value = self.generate_item(repository.settings) + assert not await repository.contains(key) - @staticmethod - def test_get_responses_from_keys(repository: ResponseRepository, valid_items: dict): - assert repository.get_responses(valid_items.keys()) == list(valid_items.values()) + assert await repository.get_response(key) is None + assert await repository.get_responses(list(valid_items) + [key]) == list(valid_items.values()) - def test_get_response_from_request(self, repository: ResponseRepository, valid_items: dict): + async def test_get_responses_from_requests(self, repository: ResponseRepository, valid_items: 
dict): key, value = choice(list(valid_items.items())) - request = self.generate_response_from_item(repository.settings, key, value).request - assert repository.get_response(request) == value + request = self.generate_response_from_item(repository.settings, key, value).request_info + assert await repository.get_response(request) == value - def test_get_responses_from_requests(self, repository: ResponseRepository, valid_items: dict): requests = [ - self.generate_response_from_item(repository.settings, key, value).request + self.generate_response_from_item(repository.settings, key, value).request_info for key, value in valid_items.items() ] - assert repository.get_responses(requests) == list(valid_items.values()) + assert await repository.get_responses(requests) == list(valid_items.values()) - def test_get_response_from_response(self, repository: ResponseRepository, valid_items: dict): + async def test_get_response_from_responses(self, repository: ResponseRepository, valid_items: dict): key, value = choice(list(valid_items.items())) response = self.generate_response_from_item(repository.settings, key, value) - assert repository.get_response(response) == value + assert await repository.get_response(response) == value - def test_get_response_from_response_on_missing(self, repository: ResponseRepository, valid_items: dict): - key, value = choice(list(valid_items.items())) - response = self.generate_response_from_item(repository.settings, key, value) - response.url = self.bad_url - response.request.url = self.bad_url - assert repository.get_response(response) is None - - def test_get_responses_from_responses(self, repository: ResponseRepository, valid_items: dict): responses = [ self.generate_response_from_item(repository.settings, key, value) for key, value in valid_items.items() ] - assert repository.get_responses(responses) == list(valid_items.values()) + assert await repository.get_responses(responses) == list(valid_items.values()) + + async def 
test_get_response_from_responses_on_missing(self, repository: ResponseRepository, valid_items: dict): + key, value = choice(list(valid_items.items())) + response = self.generate_bad_response_from_item(repository.settings, key, value) + assert await repository.get_response(response) is None - def test_get_responses_from_responses_on_missing(self, repository: ResponseRepository, valid_items: dict): responses = [ - self.generate_response_from_item(repository.settings, key, value) for key, value in valid_items.items() + self.generate_bad_response_from_item(repository.settings, key, value) for key, value in valid_items.items() ] - for response in responses: - response.url = self.bad_url - response.request.url = self.bad_url - assert repository.get_responses(responses) == [] + assert await repository.get_responses(responses) == [] - def test_save_response_from_key(self, repository: ResponseRepository): - key, value = self.generate_item(repository.settings) - assert key not in repository + async def test_set_item_from_key_value_pair(self, repository: ResponseRepository): + items = [self.generate_item(repository.settings) for _ in range(randrange(3, 6))] + assert all([not await repository.contains(key) for key, _ in items]) - repository[key] = value - assert key in repository - assert repository[key] == value + for key, value in items: + await repository._set_item_from_key_value_pair(key, value) - def test_save_responses_from_dict(self, repository: ResponseRepository): - items = dict(self.generate_item(repository.settings) for _ in range(randrange(3, 6))) - assert all(key not in repository for key in items) - - repository.update(items) - assert all(key in repository for key in items) - for key, value in items.items(): - assert repository[key] == value + assert all([await repository.contains(key) for key, _ in items]) + for key, value in items: + assert await repository.get_response(key) == value - def test_save_response(self, repository: ResponseRepository): + async def 
test_save_response_from_collection(self, repository: ResponseRepository): key, value = self.generate_item(repository.settings) - response = self.generate_response_from_item(repository.settings, key, value) - assert key not in repository + assert not await repository.contains(key) - repository.save_response(response) - assert key in repository - assert repository[key] == value + await repository.save_response((key, value)) + assert repository.contains(key) + assert await repository.get_response(key) == value - def test_save_response_fails_silently(self, repository: ResponseRepository): + async def test_save_response_from_response(self, repository: ResponseRepository): key, value = self.generate_item(repository.settings) response = self.generate_response_from_item(repository.settings, key, value) - assert key not in repository + assert not await repository.contains(key) - response.url = self.bad_url # should not return an ID for this URL format - response.request.url = response.url + await repository.save_response(response) + assert repository.contains(key) + assert await repository.get_response(key) == value - repository.save_response(response) - assert key not in repository + async def test_save_response_fails_silently(self, repository: ResponseRepository): + key, value = self.generate_item(repository.settings) + assert not await repository.contains(key) + + response = self.generate_bad_response_from_item(repository.settings, key, value) + await repository.save_response(response) + assert not await repository.contains(key) - def test_save_responses(self, repository: ResponseRepository): + async def test_save_responses_from_mapping(self, repository: ResponseRepository): items = dict(self.generate_item(repository.settings) for _ in range(randrange(3, 6))) - responses = [self.generate_response_from_item(repository.settings, key, value) for key, value in items.items()] - assert all(key not in repository for key in items) + assert all([not await repository.contains(key) 
for key in items]) - repository.save_responses(responses) - assert all(key in repository for key in items) + await repository.save_responses(items) + assert all([await repository.contains(key) for key in items]) for key, value in items.items(): - assert repository[key] == value + assert await repository.get_response(key) == value - def test_save_responses_fails_silently(self, repository: ResponseRepository): + async def test_save_responses_from_responses(self, repository: ResponseRepository): items = dict(self.generate_item(repository.settings) for _ in range(randrange(3, 6))) responses = [self.generate_response_from_item(repository.settings, key, value) for key, value in items.items()] - assert all(key not in repository for key in items) + assert all([not await repository.contains(key) for key in items]) - for response in responses: - response.url = self.bad_url # should not return an ID for this URL format - response.request.url = response.url - - repository.save_responses(responses) - assert all(key not in repository for key in items) + await repository.save_responses(responses) + assert all([await repository.contains(key) for key in items]) + for key, value in items.items(): + assert await repository.get_response(key) == value - def test_delete_response_on_missing(self, repository: ResponseRepository): - key, value = self.generate_item(repository.settings) - assert key not in repository + async def test_save_responses_fails_silently(self, repository: ResponseRepository): + items = dict(self.generate_item(repository.settings) for _ in range(randrange(3, 6))) + assert all([not await repository.contains(key) for key in items]) - with pytest.raises(MusifyKeyError): - del repository[key] + responses = [ + self.generate_bad_response_from_item(repository.settings, key, value) for key, value in items.items() + ] + await repository.save_responses(responses) + assert all([not await repository.contains(key) for key in items]) - repository.delete_response(key) + async def 
test_delete_response_on_missing(self, repository: ResponseRepository): + key, value = self.generate_item(repository.settings) + assert not await repository.contains(key) + assert not await repository.delete_response(key) @staticmethod - def test_delete_response_from_key(repository: ResponseRepository, valid_items: dict): + async def test_delete_response_from_key(repository: ResponseRepository, valid_items: dict): key, value = choice(list(valid_items.items())) - assert key in repository + assert await repository.contains(key) - del repository[key] - assert key not in repository + assert await repository.delete_response(key) + assert not await repository.contains(key) - def test_delete_response_from_request(self, repository: ResponseRepository, valid_items: dict): + async def test_delete_response_from_request(self, repository: ResponseRepository, valid_items: dict): key, value = choice(list(valid_items.items())) - request = self.generate_response_from_item(repository.settings, key, value).request - assert key in repository + request = self.generate_response_from_item(repository.settings, key, value).request_info + assert await repository.contains(key) - repository.delete_response(request) - assert key not in repository + assert await repository.delete_response(request) + assert not await repository.contains(key) - def test_delete_responses_from_requests(self, repository: ResponseRepository, valid_items: dict): + async def test_delete_responses_from_requests(self, repository: ResponseRepository, valid_items: dict): requests = [ - self.generate_response_from_item(repository.settings, key, value).request + self.generate_response_from_item(repository.settings, key, value).request_info for key, value in valid_items.items() ] for key in valid_items: - assert key in repository + assert await repository.contains(key) - repository.delete_responses(requests) + assert await repository.delete_responses(requests) == len(requests) for key in valid_items: - assert key not in 
repository + assert not await repository.contains(key) - def test_delete_response_from_response(self, repository: ResponseRepository, valid_items: dict): + async def test_delete_response_from_response(self, repository: ResponseRepository, valid_items: dict): key, value = choice(list(valid_items.items())) response = self.generate_response_from_item(repository.settings, key, value) - assert key in repository + assert await repository.contains(key) - repository.delete_response(response) - assert key not in repository + assert await repository.delete_response(response) + assert not await repository.contains(key) - def test_delete_responses_from_responses(self, repository: ResponseRepository, valid_items: dict): + async def test_delete_responses_from_responses(self, repository: ResponseRepository, valid_items: dict): responses = [ self.generate_response_from_item(repository.settings, key, value) for key, value in valid_items.items() ] for key in valid_items: - assert key in repository + assert await repository.contains(key) - repository.delete_responses(responses) + assert await repository.delete_responses(responses) == len(responses) for key in valid_items: - assert key not in repository - - def test_mapping_functionality(self, repository: ResponseRepository, valid_items: dict): - key, value = choice(list(repository.items())) - assert key in repository - assert all(key in valid_items for key in repository) - - repository.pop(key) - assert key not in repository - - items = dict(self.generate_item(repository.settings) for _ in range(randrange(3, 6))) - assert all(key not in repository for key in items) - repository.update(items) - assert all(key in repository for key in items) - - repository.clear() - assert key not in repository - with pytest.raises(MusifyKeyError): - assert repository[key] + assert not await repository.contains(key) class ResponseCacheTester(BaseResponseTester, ABC): @@ -336,16 +317,16 @@ def generate_settings() -> RequestSettings: @staticmethod 
@abstractmethod - def generate_response(settings: RequestSettings) -> Response: + def generate_response(settings: RequestSettings, session: ClientSession = None) -> ClientResponse: """ - Randomly generates a :py:class:`Response` appropriate for the given ``settings`` + Randomly generates a :py:class:`ClientResponse` appropriate for the given ``settings`` that can be persisted to the repository. """ raise NotImplementedError @classmethod @abstractmethod - def generate_cache(cls, connection: Connection) -> ResponseCache: + async def generate_cache(cls) -> ResponseCache: """ Generates a :py:class:`ResponseCache` for this backend type with many randomly generated :py:class:`ResponseRepository` objects assigned @@ -353,10 +334,12 @@ def generate_cache(cls, connection: Connection) -> ResponseCache: """ raise NotImplementedError + # noinspection PyTestUnpassedFixture @pytest.fixture - def cache(self, connection: Connection) -> ResponseCache: + async def cache(self) -> ResponseCache: """Yields a valid :py:class:`ResponseCache` to use throughout tests in this suite as a pytest.fixture.""" - return self.generate_cache(connection) + async with self.generate_cache() as cache: + yield cache @staticmethod @abstractmethod @@ -364,18 +347,36 @@ def get_repository_from_url(cache: ResponseCache, url: str) -> ResponseCache: """Returns a repository for the given ``url`` from the given ``cache``.""" raise NotImplementedError - @staticmethod - def test_close(cache: ResponseCache): - key = choice(list(cache.values())) - cache.close() + async def test_init(self, cache: ResponseCache): + assert cache.values() + for repository in cache.values(): + assert await repository.count() + + async def test_context_management(self, cache: ResponseCache): + # does not create repository backend resource until awaited or entered + settings = self.generate_settings() + assert settings.name not in cache + repository = cache.create_repository(settings) + + with pytest.raises(sqlite3.OperationalError): + 
await repository.count() + await cache + await repository.count() + + settings = self.generate_settings() + assert settings.name not in cache + repository = cache.create_repository(settings) - with pytest.raises(Exception): - cache.get_response(key) + with pytest.raises(sqlite3.OperationalError): + await repository.count() + async with cache: + await repository.count() - def test_create_repository(self, cache: ResponseCache): + async def test_create_repository(self, cache: ResponseCache): settings = self.generate_settings() assert settings.name not in cache + # noinspection PyAsyncCall cache.create_repository(settings) assert settings.name in cache assert cache[settings.name].settings == settings @@ -383,11 +384,12 @@ def test_create_repository(self, cache: ResponseCache): # does not create a repository that already exists repository = choice(list(cache.values())) with pytest.raises(CacheError): + # noinspection PyAsyncCall cache.create_repository(repository.settings) def test_get_repository_for_url(self, cache: ResponseCache): repository = choice(list(cache.values())) - url = self.generate_response(repository.settings).request.url + url = self.generate_response(repository.settings).request_info.url assert cache.get_repository_from_url(url).settings.name == repository.settings.name assert cache.get_repository_from_url(f"http://www.does-not-exist.com/{random_str()}/{random_str()}") is None @@ -396,7 +398,7 @@ def test_get_repository_for_url(self, cache: ResponseCache): def test_get_repository_for_requests(self, cache: ResponseCache): repository = choice(list(cache.values())) - requests = [self.generate_response(repository.settings).request for _ in range(3, 6)] + requests = [self.generate_response(repository.settings).request_info for _ in range(3, 6)] cache.get_repository_from_requests(requests) def test_get_repository_for_responses(self, cache: ResponseCache): @@ -414,15 +416,15 @@ def test_get_repository_for_responses(self, cache: ResponseCache): 
cache.repository_getter = None assert cache.get_repository_from_requests(responses) is None - def test_repository_operations(self, cache: ResponseCache): + async def test_repository_operations(self, cache: ResponseCache): repository = choice(list(cache.values())) response = self.generate_response(repository.settings) - key = repository.get_key_from_request(response.request) - cache.save_response(response) - assert key in repository + key = repository.get_key_from_request(response.request_info) + await cache.save_response(response) + assert await repository.contains(key) - assert cache.get_response(response) == repository.deserialize(response.text) + assert await cache.get_response(response) == repository.deserialize(await response.text()) - cache.delete_response(response) - assert key not in repository + assert await cache.delete_response(response) + assert not await repository.contains(key) diff --git a/tests/api/cache/backend/utils.py b/tests/api/cache/backend/utils.py index a2120ac1..ec0a8ff7 100644 --- a/tests/api/cache/backend/utils.py +++ b/tests/api/cache/backend/utils.py @@ -1,38 +1,42 @@ -import json from typing import Any -from urllib.parse import parse_qs, urlparse -from requests import Response +from yarl import URL -from musify.api.cache.backend.base import RequestSettings, PaginatedRequestSettings +from musify.api.cache.backend.base import RequestSettings class MockRequestSettings(RequestSettings): - @staticmethod - def get_name(value: Any) -> str | None: - if isinstance(value, dict): - return value.get("name") - elif isinstance(value, Response): - try: - return value.json().get("name") - except json.decoder.JSONDecodeError: - pass + @property + def fields(self) -> tuple[str, ...]: + return "id", + + def get_key(self, url: str | URL, *_, **__) -> tuple[str | None, ...]: + if str(url).endswith(".com"): + return (None,) + return URL(url).path.split("/")[-1] or None, @staticmethod - def get_id(url: str) -> str | None: - if "/" not in url: - return - 
return urlparse(url).path.split("/")[-1] + def get_name(response: dict[str, Any]) -> str | None: + return response.get("name") + + +class MockPaginatedRequestSettings(MockRequestSettings): + + @property + def fields(self) -> tuple[str, ...]: + return *super().fields, "offset", "size" + def get_key(self, url: str | URL, *_, **__) -> tuple[str | int | None, ...]: + base = super().get_key(url=url) + return *base, self.get_offset(url), self.get_limit(url) -class MockPaginatedRequestSettings(MockRequestSettings, PaginatedRequestSettings): @staticmethod - def get_offset(url: str) -> int: - params = parse_qs(urlparse(url).query) - return int(params.get("offset", [0])[0]) + def get_offset(url: str | URL) -> int: + params = URL(url).query + return int(params.get("offset", 0)) @staticmethod - def get_limit(url: str) -> int: - params = parse_qs(urlparse(url).query) - return int(params.get("size", [0])[0]) + def get_limit(url: str | URL) -> int: + params = URL(url).query + return int(params.get("limit", 0)) diff --git a/tests/api/cache/test_response.py b/tests/api/cache/test_response.py new file mode 100644 index 00000000..d3c237b9 --- /dev/null +++ b/tests/api/cache/test_response.py @@ -0,0 +1,37 @@ +import json +from typing import Any + +import pytest +from aiohttp import ClientRequest +from yarl import URL + +from musify.api.cache.response import CachedResponse + + +class TestCachedResponse: + + @pytest.fixture(scope="class") + def http_request(self) -> ClientRequest: + """Yields a basic :py:class:`ClientRequest` as a pytest.fixture.""" + return ClientRequest( + method="GET", url=URL("https://www.test.com"), headers={"Content-Type": "application/json"} + ) + + @pytest.fixture(scope="class") + def data(self) -> dict[str, Any]: + """Yields the expected payload dict response for a given request as a pytest.fixture.""" + return { + "1": "val1", + "2": "val2", + "3": "val3", + } + + @pytest.fixture + def http_response(self, http_request: ClientRequest, data: dict[str, Any]) -> 
CachedResponse: + """Yields the expected response for a given request as a pytest.fixture.""" + return CachedResponse(request=http_request, data=json.dumps(data)) + + async def test_read(self, http_response: CachedResponse, data: dict[str, Any]): + assert await http_response.read() == json.dumps(data).encode() + assert await http_response.text() == json.dumps(data) + assert await http_response.json() == data diff --git a/tests/api/cache/test_session.py b/tests/api/cache/test_session.py index 41f9a933..e558afd2 100644 --- a/tests/api/cache/test_session.py +++ b/tests/api/cache/test_session.py @@ -1,12 +1,15 @@ from random import choice +from typing import Any import pytest -from requests_mock import Mocker +from aioresponses import aioresponses -from musify.api.cache.backend.base import ResponseCache, Connection +from musify.api.cache.backend.base import ResponseCache from musify.api.cache.session import CachedSession from tests.api.cache.backend.test_sqlite import TestSQLiteCache as SQLiteCacheTester from tests.api.cache.backend.testers import ResponseCacheTester +from tests.api.cache.backend.utils import MockRequestSettings +from tests.utils import random_str class TestCachedSession: @@ -15,66 +18,89 @@ class TestCachedSession: def tester(self, request) -> ResponseCacheTester: return request.param - @pytest.fixture(scope="class") - def connection(self, tester: ResponseCacheTester) -> Connection: + @pytest.fixture + def connection(self, tester: ResponseCacheTester) -> Any: """Yields a valid :py:class:`Connection` to use throughout tests in this suite as a pytest.fixture.""" return tester.generate_connection() - @pytest.fixture(scope="class") - def cache(self, tester: ResponseCacheTester, connection: Connection) -> ResponseCache: + # noinspection PyTestUnpassedFixture + @pytest.fixture + async def cache(self, tester: ResponseCacheTester) -> ResponseCache: """Yields a valid :py:class:`ResponseCache` to use throughout tests in this suite as a pytest.fixture.""" - 
return tester.generate_cache(connection=connection) + async with tester.generate_cache() as cache: + yield cache - @pytest.fixture(scope="class") - def session(self, cache: ResponseCache) -> CachedSession: + @pytest.fixture + async def session(self, cache: ResponseCache) -> CachedSession: """ Yields a valid :py:class:`CachedSession` with the given ``cache`` to use throughout tests in this suite as a pytest.fixture. """ - return CachedSession(cache=cache) + async with CachedSession(cache=cache) as session: + yield session - def test_request_cached( + async def test_context_management(self, cache: ResponseCache): + # does not create repository backend resource until entered + settings = MockRequestSettings(name=random_str(20, 30)) + session = CachedSession(cache=cache) + repository = cache.create_repository(settings) + + with pytest.raises(Exception): + await repository.count() + async with session: + await repository.count() + + async def test_request_cached( self, session: CachedSession, cache: ResponseCache, tester: ResponseCacheTester, - requests_mock: Mocker + requests_mock: aioresponses ): repository = choice(list(cache.values())) - key, value = choice(list(repository.items())) - expected = tester.generate_response_from_item(repository.settings, key, value) - request = expected.request - assert key in repository + key, value = choice([(k, v) async for k, v in repository]) + assert repository.contains(key) + + expected = tester.generate_response_from_item(repository.settings, key, value, session=session) + request = expected.request_info + headers = {"Content-Type": "application/json"} - response = session.request(method=request.method, url=request.url) - assert response.json() == expected.json() - assert len(requests_mock.request_history) == 0 + async with session.request(method=request.method, url=request.url, headers=headers) as response: + assert await response.json() == await expected.json() + requests_mock.assert_not_called() - def test_request_not_cached( 
+ async def test_request_not_cached( self, session: CachedSession, cache: ResponseCache, tester: ResponseCacheTester, - requests_mock: Mocker + requests_mock: aioresponses, ): repository = choice(list(cache.values())) - expected = tester.generate_response(repository.settings) - request = expected.request + + expected = tester.generate_response(repository.settings, session=session) + request = expected.request_info + headers = {"Content-Type": "application/json"} + key = repository.get_key_from_request(request) - assert key not in repository - requests_mock.get(request.url, json=expected.json()) - - response = session.request(method=request.method, url=request.url, persist=False) - assert response.json() == expected.json() - assert len(requests_mock.request_history) == 1 - assert key not in repository - - response = session.request(method=request.method, url=request.url, persist=True) - assert response.text == expected.text - assert len(requests_mock.request_history) == 2 - assert key in repository - - response = session.request(method=request.method, url=request.url) - assert response.json() == expected.json() - assert len(requests_mock.request_history) == 2 + assert repository.contains(key) + + requests_mock.get(request.url, body=await expected.text(), repeat=True) + + async with session.request(method=request.method, url=request.url, headers=headers, persist=False) as response: + assert await response.json() == await expected.json() + assert len(requests_mock.requests) == 1 + assert sum(len(reqs) for reqs in requests_mock.requests.values()) == 1 + assert not await repository.contains(key) + + async with session.request(method=request.method, url=request.url, headers=headers, persist=True) as response: + assert await response.text() == await expected.text() + assert len(requests_mock.requests) == 1 + assert sum(len(reqs) for reqs in requests_mock.requests.values()) == 2 + assert await repository.contains(key) + + async with session.request(method=request.method, 
url=request.url, headers=headers) as response: + assert await response.json() == await expected.json() + assert len(requests_mock.requests) == 1 + assert sum(len(reqs) for reqs in requests_mock.requests.values()) == 2 diff --git a/tests/api/test_authorise.py b/tests/api/test_authorise.py index 08d46d29..0d583d6a 100644 --- a/tests/api/test_authorise.py +++ b/tests/api/test_authorise.py @@ -1,15 +1,17 @@ import json import os +import re import socket from datetime import datetime, timedelta from os.path import join from pathlib import Path from typing import Any -from urllib.parse import urlparse, parse_qs +from urllib.parse import unquote import pytest +from aioresponses import aioresponses from pytest_mock import MockerFixture -from requests_mock import Mocker +from yarl import URL from musify import MODULE_ROOT from musify.api.authorise import APIAuthoriser @@ -94,7 +96,7 @@ def test_save_token(self, authoriser: APIAuthoriser, token: dict[str, Any], tmp_ assert token == token_saved - def test_user_auth(self, authoriser: APIAuthoriser, mocker: MockerFixture, requests_mock: Mocker): + async def test_user_auth(self, authoriser: APIAuthoriser, mocker: MockerFixture, requests_mock: aioresponses): user_url = f"http://{APIAuthoriser._user_auth_socket_address}:{APIAuthoriser._user_auth_socket_port + 1}" authoriser.auth_args = {"url": ""} authoriser.user_args = {"url": user_url} @@ -106,66 +108,66 @@ def test_user_auth(self, authoriser: APIAuthoriser, mocker: MockerFixture, reque def check_url(url: str): """Check the URL given to the webopen call""" assert url.startswith(user_url) - assert parse_qs(urlparse(url).query)["redirect_uri"][0] == redirect_uri + assert unquote(URL(url).query["redirect_uri"]) == redirect_uri socket_listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - requests_mock.post(user_url) + requests_mock.post(re.compile(user_url)) mocker.patch(f"{MODULE_ROOT}.api.authorise.webopen", new=check_url) mocker.patch.object(socket.socket, 
attribute="accept", return_value=(socket_listener, None)) mocker.patch.object(socket.socket, attribute="send") mocker.patch.object(socket.socket, attribute="recv", return_value=response.encode("utf-8")) - authoriser._authorise_user() + await authoriser._authorise_user() assert authoriser.auth_args["data"]["code"] == code assert authoriser.auth_args["data"]["redirect_uri"] == redirect_uri - def test_request_token_1(self, authoriser: APIAuthoriser, token: dict[str, Any], requests_mock: Mocker): + async def test_request_token_1(self, authoriser: APIAuthoriser, token: dict[str, Any], requests_mock: aioresponses): authoriser.auth_args = { "url": f"http://{APIAuthoriser._user_auth_socket_address}:{APIAuthoriser._user_auth_socket_port + 1}", "data": {"grant_type": "authorization_code", "code": None}, } - requests_mock.post(authoriser.auth_args["url"], json=token) + requests_mock.post(authoriser.auth_args["url"], payload=token) - result = authoriser._request_token(**authoriser.auth_args) + result = await authoriser._request_token(**authoriser.auth_args) assert {k: v for k, v in result.items() if k not in self.refresh_test_keys} == token assert "granted_at" in result assert "expires_at" not in result assert "refresh_token" not in result - def test_request_token_2(self, authoriser: APIAuthoriser, token: dict[str, Any], requests_mock: Mocker): + async def test_request_token_2(self, authoriser: APIAuthoriser, token: dict[str, Any], requests_mock: aioresponses): authoriser.auth_args = { "url": f"http://{APIAuthoriser._user_auth_socket_address}:{APIAuthoriser._user_auth_socket_port + 1}", "data": {"grant_type": "authorization_code", "code": None}, } authoriser.token = {"refresh_token": "new token"} expires_in_token = {"expires_in": 3600} - requests_mock.post(authoriser.auth_args["url"], json=token | expires_in_token) + requests_mock.post(authoriser.auth_args["url"], payload=token | expires_in_token) - result = authoriser._request_token(**authoriser.auth_args) + result = await 
authoriser._request_token(**authoriser.auth_args) assert {k: v for k, v in result.items() if k not in self.refresh_test_keys} == token | expires_in_token assert "granted_at" in result assert "expires_at" in result assert result["refresh_token"] == authoriser.token["refresh_token"] - def test_request_token_3(self, authoriser: APIAuthoriser, token: dict[str, Any], requests_mock: Mocker): + async def test_request_token_3(self, authoriser: APIAuthoriser, token: dict[str, Any], requests_mock: aioresponses): authoriser.auth_args = { "url": f"http://{APIAuthoriser._user_auth_socket_address}:{APIAuthoriser._user_auth_socket_port + 1}", "data": {"grant_type": "authorization_code", "code": None}, } response = token | {"refresh_token": "received token"} - requests_mock.post(authoriser.auth_args["url"], json=response) + requests_mock.post(authoriser.auth_args["url"], payload=response) - result = authoriser._request_token(**authoriser.auth_args) + result = await authoriser._request_token(**authoriser.auth_args) assert {k: v for k, v in result.items() if k not in self.refresh_test_keys} == token assert "granted_at" in result assert "expires_at" not in result assert result["refresh_token"] == response["refresh_token"] - def test_token_test(self, authoriser: APIAuthoriser, token: dict[str, Any]): + async def test_token_test(self, authoriser: APIAuthoriser, token: dict[str, Any]): authoriser.token = {"expires_at": (datetime.now() + timedelta(seconds=3000)).timestamp()} authoriser.test_expiry = 1500 - assert authoriser.test_token() + assert await authoriser.test_token() def test_error_test(self, authoriser: APIAuthoriser, token: dict[str, Any]): authoriser.token = {"error": "error message"} @@ -174,8 +176,8 @@ def test_error_test(self, authoriser: APIAuthoriser, token: dict[str, Any]): authoriser.token = token assert authoriser._test_no_error() - def test_valid_response(self, authoriser: APIAuthoriser, token: dict[str, Any], requests_mock: Mocker): - assert 
authoriser._test_valid_response() + async def test_valid_response(self, authoriser: APIAuthoriser, token: dict[str, Any], requests_mock: aioresponses): + assert await authoriser._test_valid_response() authoriser.header_key = "Authorization" authoriser.header_prefix = "Bearer " @@ -184,16 +186,16 @@ def test_valid_response(self, authoriser: APIAuthoriser, token: dict[str, Any], authoriser.test_args = {"url": "http://locahost/test"} authoriser.test_condition = lambda x: x["test_result"] == "valid" - requests_mock.get(authoriser.test_args["url"], json={"test_result": "valid"}) - assert authoriser._test_valid_response() + requests_mock.get(authoriser.test_args["url"], payload={"test_result": "valid"}) + assert await authoriser._test_valid_response() authoriser.test_condition = lambda x: x == "valid" - requests_mock.get(authoriser.test_args["url"], text="valid") - assert authoriser._test_valid_response() + requests_mock.get(authoriser.test_args["url"], body="valid") + assert await authoriser._test_valid_response() - requests_mock.get(authoriser.test_args["url"], text="invalid") - assert not authoriser._test_valid_response() + requests_mock.get(authoriser.test_args["url"], body="invalid") + assert not await authoriser._test_valid_response() def test_expiry(self, authoriser: APIAuthoriser): authoriser.token = {"expires_at": (datetime.now() + timedelta(seconds=3000)).timestamp()} @@ -211,18 +213,18 @@ def test_expiry(self, authoriser: APIAuthoriser): authoriser.test_expiry = 2000 assert not authoriser._test_expiry() - def test_auth_new_token(self, token: dict[str, Any], token_file_path: str, requests_mock: Mocker): + async def test_auth_new_token(self, token: dict[str, Any], token_file_path: str, requests_mock: aioresponses): authoriser = APIAuthoriser(name="test", auth_args={"url": "http://localhost/auth"}, test_expiry=1000) response = {"access_token": "valid token", "expires_in": 3000, "refresh_token": "new_refresh"} - requests_mock.post(authoriser.auth_args["url"], 
json=response) + requests_mock.post(authoriser.auth_args["url"], payload=response) - authoriser() + await authoriser.authorise() expected_header = {"Authorization": "Bearer valid token"} assert authoriser.headers == expected_header assert authoriser.token["refresh_token"] == "new_refresh" - def test_auth_load_and_token_valid(self, token_file_path: str, requests_mock: Mocker): + async def test_auth_load_and_token_valid(self, token_file_path: str, requests_mock: aioresponses): authoriser = APIAuthoriser( name="test", test_args={"url": "http://localhost/test"}, @@ -230,14 +232,14 @@ def test_auth_load_and_token_valid(self, token_file_path: str, requests_mock: Mo token_file_path=token_file_path, ) - requests_mock.get(authoriser.test_args["url"], json={"test": "valid"}) + requests_mock.get(authoriser.test_args["url"], payload={"test": "valid"}) # loads token, token is valid, no refresh needed - authoriser() + await authoriser.authorise() expected_header = {"Authorization": f"Bearer {authoriser.token["access_token"]}"} assert authoriser.headers == expected_header - def test_auth_force_load_and_token_valid(self, token_file_path: str): + async def test_auth_force_load_and_token_valid(self, token_file_path: str): authoriser = APIAuthoriser( name="test", token={"this token": "is not valid"}, @@ -248,32 +250,36 @@ def test_auth_force_load_and_token_valid(self, token_file_path: str): ) # force load from json despite being given token - authoriser(force_load=True) + await authoriser.authorise(force_load=True) expected_header = {"new_key": f"prefix - {authoriser.token["access_token"]}"} assert authoriser.headers == expected_header | authoriser.header_extra - def test_auth_force_new_and_no_args(self, token: dict[str, Any], token_file_path: str): + async def test_auth_force_new_and_no_args(self, token: dict[str, Any], token_file_path: str): authoriser = APIAuthoriser(name="test", token=token, token_file_path=token_file_path) # force new despite being given token and token file 
path with pytest.raises(APIError): - authoriser(force_new=True) + await authoriser.authorise(force_new=True) - def test_auth_new_token_and_no_refresh(self, token: dict[str, Any], token_file_path: str, requests_mock: Mocker): + async def test_auth_new_token_and_no_refresh( + self, token: dict[str, Any], token_file_path: str, requests_mock: aioresponses + ): authoriser = APIAuthoriser( name="test", auth_args={"url": "http://localhost/auth"}, token_key_path=["1", "2", "code"] ) - requests_mock.post(authoriser.auth_args["url"], json={"1": {"2": {"code": "token"}}}) + requests_mock.post(authoriser.auth_args["url"], payload={"1": {"2": {"code": "token"}}}) - authoriser() + await authoriser.authorise() expected_header = {"Authorization": "Bearer token"} assert authoriser.headers == expected_header - def test_auth_new_token_and_refresh_valid(self, token: dict[str, Any], token_file_path: str, requests_mock: Mocker): + async def test_auth_new_token_and_refresh_valid( + self, token: dict[str, Any], token_file_path: str, requests_mock: aioresponses + ): authoriser = APIAuthoriser( name="test", refresh_args={"url": "http://localhost/refresh"}, @@ -283,15 +289,15 @@ def test_auth_new_token_and_refresh_valid(self, token: dict[str, Any], token_fil ) response = {"get_token": "valid token", "expires_in": 3000, "refresh_token": "new_refresh"} - requests_mock.post(authoriser.refresh_args["url"], json=response) + requests_mock.post(authoriser.refresh_args["url"], payload=response, repeat=True) - authoriser() + await authoriser.authorise() expected_header = {"Authorization": "Bearer valid token"} assert authoriser.headers == expected_header assert authoriser.token["refresh_token"] == "new_refresh" - def test_auth_new_token_and_refresh_invalid( - self, token: dict[str, Any], token_file_path: str, requests_mock: Mocker + async def test_auth_new_token_and_refresh_invalid( + self, token: dict[str, Any], token_file_path: str, requests_mock: aioresponses ): authoriser = APIAuthoriser( 
name="test", @@ -302,22 +308,22 @@ def test_auth_new_token_and_refresh_invalid( ) response = {"expires_in": 10} - requests_mock.post(authoriser.refresh_args["url"], json=response) + requests_mock.post(authoriser.refresh_args["url"], payload=response, repeat=True) with pytest.raises(APIError): - authoriser() + await authoriser.authorise() authoriser.auth_args = {"url": "http://localhost/auth"} response = {"get_token": "valid token", "expires_in": 20, "refresh_token": "new_refresh"} - requests_mock.post(authoriser.auth_args["url"], json=response) + requests_mock.post(authoriser.auth_args["url"], payload=response) with pytest.raises(APIError): - authoriser() + await authoriser.authorise() response = {"get_token": "valid token", "expires_in": 3000, "refresh_token": "new_refresh"} - requests_mock.post(authoriser.auth_args["url"], json=response) + requests_mock.post(authoriser.auth_args["url"], payload=response) - authoriser() + await authoriser.authorise() expected_header = {"Authorization": "Bearer valid token"} assert authoriser.headers == expected_header assert authoriser.token["refresh_token"] == "new_refresh" diff --git a/tests/api/test_request.py b/tests/api/test_request.py index 33ddc88d..c321ff1e 100644 --- a/tests/api/test_request.py +++ b/tests/api/test_request.py @@ -1,20 +1,18 @@ import json from typing import Any +import aiohttp import pytest -import requests -from requests import Response -from requests_mock import Mocker -# noinspection PyProtectedMember,PyUnresolvedReferences -from requests_mock.request import _RequestObjectProxy as Request -# noinspection PyProtectedMember,PyUnresolvedReferences -from requests_mock.response import _Context as Context +from aiohttp import ClientRequest +from aioresponses import aioresponses, CallbackResult +from yarl import URL from musify.api.authorise import APIAuthoriser from musify.api.cache.backend.base import ResponseCache from musify.api.cache.backend.sqlite import SQLiteCache +from musify.api.cache.response 
import CachedResponse from musify.api.cache.session import CachedSession -from musify.api.exception import APIError +from musify.api.exception import APIError, RequestError from musify.api.request import RequestHandler from tests.api.cache.backend.utils import MockRequestSettings @@ -22,19 +20,26 @@ class TestRequestHandler: @pytest.fixture - def authoriser(self, token: dict[str, Any]) -> APIAuthoriser: - """Yield a simple :py:class:`APIAuthoriser` object""" - return APIAuthoriser(name="test", token=token) + def url(self) -> URL: + """Yield a simple :py:class:`URL` object""" + return URL("http://test.com") @pytest.fixture def cache(self) -> ResponseCache: """Yield a simple :py:class:`ResponseCache` object""" return SQLiteCache.connect_with_in_memory_db() + @pytest.fixture + def authoriser(self, token: dict[str, Any]) -> APIAuthoriser: + """Yield a simple :py:class:`APIAuthoriser` object""" + return APIAuthoriser(name="test", token=token) + @pytest.fixture def request_handler(self, authoriser: APIAuthoriser, cache: ResponseCache) -> RequestHandler: """Yield a simple :py:class:`RequestHandler` object""" - return RequestHandler(authoriser=authoriser, cache=cache) + return RequestHandler.create( + authoriser=authoriser, cache=cache, headers={"Content-Type": "application/json"} + ) @pytest.fixture def token(self) -> dict[str, Any]: @@ -45,133 +50,145 @@ def token(self) -> dict[str, Any]: "scope": "test-read" } - # noinspection PyTestUnpassedFixture - def test_init(self, token: dict[str, Any], authoriser: APIAuthoriser, cache: ResponseCache): - request_handler = RequestHandler(authoriser=authoriser) - assert request_handler.authoriser.token == token - assert not isinstance(request_handler.session, CachedSession) + async def test_init(self, token: dict[str, Any], authoriser: APIAuthoriser, cache: ResponseCache): + handler = RequestHandler.create(authoriser=authoriser, cache=cache) + assert handler.authoriser.token == token + assert not isinstance(handler.session, 
CachedSession) + + handler = RequestHandler.create(authoriser=authoriser, cache=cache) + assert handler.closed - request_handler = RequestHandler(authoriser=authoriser, cache=cache) - assert isinstance(request_handler.session, CachedSession) + async def test_context_management(self, request_handler: RequestHandler): + with pytest.raises(RequestError): + await request_handler.authorise() - request_handler.authorise() - for k, v in request_handler.authoriser.headers.items(): - assert request_handler.session.headers.get(k) == v + async with request_handler as handler: + assert isinstance(handler.session, CachedSession) - def test_context_management(self, authoriser: APIAuthoriser): - with RequestHandler(authoriser=authoriser) as handler: for k, v in handler.authoriser.headers.items(): assert handler.session.headers.get(k) == v - def test_check_response_codes(self, request_handler: RequestHandler): - response = Response() + async def test_check_response_codes(self, request_handler: RequestHandler, url: URL): + headers = {"Content-Type": "application/json"} + request = ClientRequest(method="GET", url=url, headers=headers) # error message not found, no fail - response.status_code = 201 - assert not request_handler._handle_unexpected_response(response=response) + response = CachedResponse(request=request, data="") + response.status = 201 + assert not await request_handler._handle_unexpected_response(response=response) # error message found, no fail expected = {"error": {"message": "request failed"}} - response._content = json.dumps(expected).encode() - assert request_handler._handle_unexpected_response(response=response) + response = CachedResponse(request=request, data=json.dumps(expected)) + assert await request_handler._handle_unexpected_response(response=response) # error message not found, raises exception - response.status_code = 400 + response.status = 400 with pytest.raises(APIError): - request_handler._handle_unexpected_response(response=response) - - def 
test_check_for_wait_time(self, request_handler: RequestHandler): - response = Response() + await request_handler._handle_unexpected_response(response=response) + async def test_check_for_wait_time(self, request_handler: RequestHandler, url: URL): # no header - assert not request_handler._handle_wait_time(response=response) + request = ClientRequest(method="GET", url=URL("http://test.com")) + response = CachedResponse(request, data="") + assert not await request_handler._handle_wait_time(response=response) # expected key not in headers - response.headers = {"header key": "header value"} - assert not request_handler._handle_wait_time(response=response) + headers = {"header key": "header value"} + request = ClientRequest(method="GET", url=url, headers=headers) + response = CachedResponse(request, data="") + assert not await request_handler._handle_wait_time(response=response) # expected key in headers and time is short - response.headers = {"retry-after": "1"} + headers = {"retry-after": "1"} + request = ClientRequest(method="GET", url=url, headers=headers) + response = CachedResponse(request, data="") assert request_handler.timeout >= 1 - assert request_handler._handle_wait_time(response=response) + assert await request_handler._handle_wait_time(response=response) # expected key in headers and time too long - response.headers = {"retry-after": "2000"} + headers = {"retry-after": "2000"} + request = ClientRequest(method="GET", url=url, headers=headers) + response = CachedResponse(request, data="") assert request_handler.timeout < 2000 with pytest.raises(APIError): - request_handler._handle_wait_time(response=response) + await request_handler._handle_wait_time(response=response) - def test_response_as_json(self, request_handler: RequestHandler): - response = Response() - response._content = "simple text should not be returned".encode() - assert request_handler._response_as_json(response) == {} + async def test_response_as_json(self, request_handler: RequestHandler, 
url: URL): + request = ClientRequest(method="GET", url=url, headers={"Content-Type": "application/json"}) + response = CachedResponse(request=request, data="simple text should not be returned") + assert await request_handler._response_as_json(response) == {} expected = {"key": "valid json"} - response._content = json.dumps(expected).encode() - assert request_handler._response_as_json(response) == expected + response = CachedResponse(request=request, data=json.dumps(expected)) + assert await request_handler._response_as_json(response) == expected - def test_cache_usage(self, request_handler: RequestHandler, requests_mock: Mocker): - test_url = "http://localhost/test" + # noinspection PyTestUnpassedFixture + async def test_cache_usage(self, request_handler: RequestHandler, requests_mock: aioresponses): + url = "http://localhost/test" expected_json = {"key": "value"} - requests_mock.get(test_url, json=expected_json) + requests_mock.get(url, payload=expected_json, repeat=True) - repository = request_handler.cache.create_repository(MockRequestSettings(name="test")) - request_handler.cache.repository_getter = lambda _, __: repository + async with request_handler as handler: + repository = await handler.session.cache.create_repository(MockRequestSettings(name="test")) + handler.session.cache.repository_getter = lambda _, __: repository - response = request_handler._request(method="GET", url=test_url, persist=False) - assert response.json() == expected_json - assert requests_mock.call_count == 1 + async with handler._request(method="GET", url=url, persist=False) as response: + assert await response.json() == expected_json + requests_mock.assert_called_once() - key = repository.get_key_from_request(response.request) - assert repository.get_response(key) is None + key = repository.get_key_from_request(response.request_info) + assert await repository.get_response(key) is None - response = request_handler._request(method="GET", url=test_url, persist=True) - assert 
response.json() == expected_json - assert requests_mock.call_count == 2 - assert repository.get_response(key) + async with handler._request(method="GET", url=url, persist=True) as response: + assert await response.json() == expected_json + assert sum(len(reqs) for reqs in requests_mock.requests.values()) == 2 + assert await repository.get_response(key) - response = request_handler._request(method="GET", url=test_url) - assert response.json() == expected_json - assert requests_mock.call_count == 2 + async with handler._request(method="GET", url=url) as response: + assert await response.json() == expected_json + assert sum(len(reqs) for reqs in requests_mock.requests.values()) == 2 - repository.clear() - response = request_handler._request(method="GET", url=test_url) - assert response.json() == expected_json - assert requests_mock.call_count == 3 + await repository.clear() + async with handler._request(method="GET", url=url) as response: + assert await response.json() == expected_json + assert sum(len(reqs) for reqs in requests_mock.requests.values()) == 3 - def test_request(self, request_handler: RequestHandler, requests_mock: Mocker): + async def test_request(self, request_handler: RequestHandler, requests_mock: aioresponses): def raise_error(*_, **__): """Just raise a ConnectionError""" - raise requests.exceptions.ConnectionError() + raise aiohttp.ClientConnectionError() # handles connection errors safely url = "http://localhost/text_response" - requests_mock.get(url, text=raise_error) - assert request_handler._request(method="GET", url=url) is None - - url = "http://localhost/test" - expected_json = {"key": "value"} - - requests_mock.get(url, json=expected_json) - assert request_handler.request(method="GET", url=url) == expected_json - - # ignore headers on good status code - requests_mock.post(url, status_code=200, headers={"retry-after": "2000"}, json=expected_json) - assert request_handler.request(method="POST", url=url) == expected_json - - # fail on long 
wait time - requests_mock.put(url, status_code=429, headers={"retry-after": "2000"}) - assert request_handler.timeout < 2000 - with pytest.raises(APIError): - request_handler.put(url=url) - - # fail on breaking status code - requests_mock.delete(url, status_code=400) - assert request_handler.timeout < 2000 - with pytest.raises(APIError): - request_handler.delete(method="GET", url=url) - - def test_backoff(self, request_handler: RequestHandler, requests_mock: Mocker): + requests_mock.get(url, callback=raise_error, repeat=True) + async with request_handler as handler: + async with handler._request(method="GET", url=url) as response: + assert response is None + + url = "http://localhost/test" + expected_json = {"key": "value"} + + requests_mock.get(url, payload=expected_json) + assert await handler.request(method="GET", url=url) == expected_json + + # ignore headers on good status code + requests_mock.post(url, status=200, headers={"retry-after": "2000"}, payload=expected_json) + assert await handler.request(method="POST", url=url) == expected_json + + # fail on long wait time + requests_mock.put(url, status=429, headers={"retry-after": "2000"}) + assert handler.timeout < 2000 + with pytest.raises(APIError): + await handler.put(url=url) + + # fail on breaking status code + requests_mock.delete(url, status=400) + assert handler.timeout < 2000 + with pytest.raises(APIError): + await handler.delete(method="GET", url=url) + + async def test_backoff(self, request_handler: RequestHandler, requests_mock: aioresponses): url = "http://localhost/test" expected_json = {"key": "value"} backoff_limit = 3 @@ -181,15 +198,15 @@ def test_backoff(self, request_handler: RequestHandler, requests_mock: Mocker): request_handler.backoff_factor = 2 request_handler.backoff_count = backoff_limit + 2 - def backoff(_: Request, context: Context) -> dict[str, Any]: - """Return response based on how many times backoff process has happened""" - if requests_mock.call_count < backoff_limit: - 
context.status_code = 408 - return {"error": {"message": "fail"}} + def callback(method: str, *_, **__) -> CallbackResult: + """Modify mock response based on how many times backoff process has happened""" + if sum(len(reqs) for reqs in requests_mock.requests.values()) < backoff_limit: + payload = {"error": {"message": "fail"}} + return CallbackResult(method=method, status=408, payload=payload) - context.status_code = 200 - return expected_json + return CallbackResult(method=method, status=200, payload=expected_json) - requests_mock.patch(url, json=backoff) - assert request_handler.patch(url=url) == expected_json - assert requests_mock.call_count == backoff_limit + requests_mock.patch(url, callback=callback, repeat=True) + async with request_handler as handler: + assert await handler.patch(url=url) == expected_json + assert sum(len(reqs) for reqs in requests_mock.requests.values()) == backoff_limit diff --git a/tests/conftest.py b/tests/conftest.py index 2170a2a6..f479557b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ import pytest import yaml from _pytest.fixtures import SubRequest +from aioresponses import aioresponses from musify import MODULE_ROOT from musify.libraries.remote.core.enum import RemoteObjectType @@ -53,6 +54,20 @@ def remove_file_handler(c: dict[str, Any]) -> None: logging.config.dictConfig(log_config) +def pytest_collection_modifyitems(items: list[pytest.Function]): + """Modifies test items in-place, ordering them based on assigned marks.""" + marker_name_order = [] # currently not implemented + + def _get_item_order_index(item: pytest.Function) -> int: + try: + name = next(marker.name for marker in item.own_markers if marker.name.casefold() in marker_name_order) + return marker_name_order.index(name.casefold()) + except (StopIteration, ValueError): + return len(marker_name_order) + + items.sort(key=_get_item_order_index) + + # This is a fork of the pytest-lazy-fixture package # Fixes applied for issues with pytest >8.0: 
https://github.com/TvoroG/pytest-lazy-fixture/issues/65 # noinspection PyProtectedMember @@ -258,6 +273,13 @@ def __eq__(self, other): return self.name == other.name +@pytest.fixture +def requests_mock(): + """Yields an initialised :py:class:`aioresponses` object for mocking aiohttp requests as a pytest.fixture.""" + with aioresponses() as m: + yield m + + @pytest.fixture def path(request: pytest.FixtureRequest | SubRequest, tmp_path: Path) -> str: """ @@ -292,21 +314,20 @@ def spotify_wrangler(): return SpotifyDataWrangler() -@pytest.fixture(scope="session") -def spotify_api() -> SpotifyAPI: - """Yield an authorised :py:class:`SpotifyAPI` object""" - token = {"access_token": "fake access token", "token_type": "Bearer", "scope": "test-read"} - api = SpotifyAPI(token=token) - # blocks any token tests - api.handler.authoriser.test_args = None - api.handler.authoriser.test_expiry = 0 - api.handler.authoriser.test_condition = None - with api as a: - yield a - - @pytest.fixture(scope="session") def spotify_mock() -> SpotifyMock: """Yield an authorised and configured :py:class:`SpotifyMock` object""" with SpotifyMock() as m: yield m + + +@pytest.fixture(scope="session") +async def spotify_api(spotify_mock: SpotifyMock) -> SpotifyAPI: + """Yield an authorised :py:class:`SpotifyAPI` object""" + token = {"access_token": "fake access token", "token_type": "Bearer", "scope": "test-read"} + # disable any token tests by settings test_* kwargs as appropriate + api = SpotifyAPI(token=token, test_args=None, test_expiry=0, test_condition=None) + api.handler.backoff_count = 1 # TODO: remove me + + async with api as a: + yield a diff --git a/tests/libraries/local/conftest.py b/tests/libraries/local/conftest.py index 7b435b39..98b9ddf9 100644 --- a/tests/libraries/local/conftest.py +++ b/tests/libraries/local/conftest.py @@ -44,6 +44,7 @@ def track_wma(path: str, remote_wrangler: RemoteDataWrangler) -> WMA: return WMA(file=path, remote_wrangler=remote_wrangler) +# noinspection 
PyUnresolvedReferences @pytest.fixture(params=[ pytest.lazy_fixture("track_flac"), pytest.lazy_fixture("track_mp3"), diff --git a/tests/libraries/local/playlist/test_m3u.py b/tests/libraries/local/playlist/test_m3u.py index fa6ed49e..3ef27d78 100644 --- a/tests/libraries/local/playlist/test_m3u.py +++ b/tests/libraries/local/playlist/test_m3u.py @@ -12,7 +12,7 @@ from tests.libraries.local.playlist.testers import LocalPlaylistTester from tests.libraries.local.track.utils import random_track, random_tracks from tests.libraries.local.utils import path_playlist_m3u -from tests.utils import path_txt +from tests.utils import path_txt, path_resources class TestM3U(LocalPlaylistTester): @@ -64,16 +64,17 @@ def test_load_fake_file_with_fake_tracks(self, tracks: list[LocalTrack], tmp_pat pl.load(tracks + tracks_random[:4]) assert pl.tracks == tracks_random[:4] - def test_load_file_with_no_tracks( - self, tracks_actual: list[LocalTrack], tracks_limited: list[LocalTrack], path_mapper: PathMapper - ): - pl = M3U(path=path_playlist_m3u, path_mapper=path_mapper) + # noinspection PyTestUnpassedFixture + def test_load_file_with_no_tracks(self, tracks_actual: list[LocalTrack], tracks_limited: list[LocalTrack]): + pl = M3U(path=path_playlist_m3u, path_mapper=PathStemMapper(stem_map={"../": path_resources})) assert pl.path == path_playlist_m3u assert pl.tracks == tracks_actual + assert [track.path for track in pl] == [track.path for track in tracks_actual] # reloads only with given tracks that match conditions i.e. 
paths to include pl.load(tracks_limited) + assert [track.path for track in pl] == [track.path for track in tracks_limited if track in tracks_actual] assert pl.tracks == [track for track in tracks_limited if track in tracks_actual] # ...and then reloads all tracks from disk that match conditions when no tracks are given diff --git a/tests/libraries/local/utils.py b/tests/libraries/local/utils.py index 31faa870..0d027066 100644 --- a/tests/libraries/local/utils.py +++ b/tests/libraries/local/utils.py @@ -10,8 +10,8 @@ path_track_resources = join(path_resources, "track") path_track_all: set[str] = {path for c in TRACK_CLASSES for path in c.get_filepaths(path_track_resources)} -path_track_flac = join(path_track_resources, "noise_flac.flac") -path_track_mp3 = join(path_track_resources, "noise_mp3.mp3") +path_track_flac = join(path_track_resources, "NOISE_FLaC.flac") +path_track_mp3 = join(path_track_resources, "noiSE_mP3.mp3") path_track_m4a = join(path_track_resources, "noise_m4a.m4a") path_track_wma = join(path_track_resources, "noise_wma.wma") path_track_img = join(path_track_resources, "track_image.jpg") diff --git a/tests/libraries/remote/core/api.py b/tests/libraries/remote/core/api.py index 4259e085..0e2e430d 100644 --- a/tests/libraries/remote/core/api.py +++ b/tests/libraries/remote/core/api.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod +from collections.abc import Iterable from copy import deepcopy from random import sample, choice from typing import Any -from urllib.parse import parse_qs +from urllib.parse import unquote import pytest -# noinspection PyProtectedMember,PyUnresolvedReferences -from requests_mock.request import _RequestObjectProxy as Request +from yarl import URL from musify.libraries.remote.core.api import RemoteAPI from musify.libraries.remote.core.enum import RemoteObjectType @@ -77,14 +77,13 @@ def assert_different(source: dict[str, Any], test: dict[str, Any], *omit: str): assert {k: v for k, v in test.items() if k not in omit} != 
expected @staticmethod - def assert_params(requests: list[Request], params: dict[str, Any] | list[dict[str, Any]]): + def assert_params(requests: Iterable[URL], params: dict[str, Any] | list[dict[str, Any]]): """Check for expected ``params`` in the given ``requests``""" - for request in requests: - request_params = parse_qs(request.query) + for url in requests: if isinstance(params, list): - assert any(request_params[k][0] == param[k] for param in params for k in param) + assert any(unquote(url.query[k]) == param[k] for param in params for k in param) continue for k, v in params.items(): - assert k in request_params - assert request_params[k][0] == params[k] + assert k in url.query + assert unquote(url.query[k]) == params[k] diff --git a/tests/libraries/remote/core/library.py b/tests/libraries/remote/core/library.py index b6162cbc..df71a27b 100644 --- a/tests/libraries/remote/core/library.py +++ b/tests/libraries/remote/core/library.py @@ -1,3 +1,4 @@ +import re from abc import ABCMeta, abstractmethod from collections.abc import Collection, Mapping from copy import copy, deepcopy @@ -28,8 +29,8 @@ def collection_merge_items(self, *args, **kwargs) -> list[RemoteTrack]: @staticmethod @pytest.mark.slow - def test_load_playlists(library_unloaded: RemoteLibrary): - library_unloaded.load_playlists() + async def test_load_playlists(library_unloaded: RemoteLibrary): + await library_unloaded.load_playlists() # only loaded playlists matching the filter assert len(library_unloaded.playlists) == 10 @@ -39,75 +40,75 @@ def test_load_playlists(library_unloaded: RemoteLibrary): assert len(library_unloaded.tracks_in_playlists) == unique_tracks_count # does not add duplicates to the loaded lists - library_unloaded.load_playlists() + await library_unloaded.load_playlists() assert len(library_unloaded.playlists) == 10 assert len(library_unloaded.tracks_in_playlists) == unique_tracks_count @abstractmethod - def test_load_tracks(self, *_, **__): + async def test_load_tracks(self, *_, 
**__): raise NotImplementedError @abstractmethod - def test_load_saved_albums(self, *_, **__): + async def test_load_saved_albums(self, *_, **__): raise NotImplementedError @abstractmethod - def test_load_saved_artists(self, *_, **__): + async def test_load_saved_artists(self, *_, **__): raise NotImplementedError @staticmethod @pytest.mark.slow - def test_load(library_unloaded: RemoteLibrary): + async def test_load(library_unloaded: RemoteLibrary): assert not library_unloaded.playlists assert not library_unloaded.tracks assert not library_unloaded.albums assert not library_unloaded.artists - library_unloaded.load() + await library_unloaded.load() assert library_unloaded.playlists assert library_unloaded.tracks assert library_unloaded.albums assert library_unloaded.artists - def test_enrich_tracks(self, library: RemoteLibrary, *_, **__): + async def test_enrich_tracks(self, library: RemoteLibrary, *_, **__): """ This function and related test can be implemented by child classes. Just check it doesn't fail for the base test class. """ - assert library.enrich_tracks() is None + assert await library.enrich_tracks() is None - def test_enrich_saved_artists(self, library: RemoteLibrary, *_, **__): + async def test_enrich_saved_artists(self, library: RemoteLibrary, *_, **__): """ This function and related test can be implemented by child classes. Just check it doesn't fail for the base test class. """ - assert library.enrich_saved_artists() is None + assert await library.enrich_saved_artists() is None - def test_enrich_saved_albums(self, library: RemoteLibrary, *_, **__): + async def test_enrich_saved_albums(self, library: RemoteLibrary, *_, **__): """ This function and related test can be implemented by child classes. Just check it doesn't fail for the base test class. 
""" - assert library.enrich_saved_albums() is None + assert await library.enrich_saved_albums() is None @staticmethod - def test_extend(library: RemoteLibrary, collection_merge_items: list[RemoteTrack], api_mock: RemoteMock): + async def test_extend(library: RemoteLibrary, collection_merge_items: list[RemoteTrack], api_mock: RemoteMock): # extend on already existing tracks with duplicates not allowed library_tracks_start = copy(library.tracks) tracks_existing = library.tracks[:10] assert len(tracks_existing) > 0 - library.extend(tracks_existing, allow_duplicates=False) + await library.extend(tracks_existing, allow_duplicates=False) assert len(library) == len(library_tracks_start) # no change - assert len(api_mock.request_history) == 0 # no requests made + api_mock.assert_not_called() # no requests made - library.extend(collection_merge_items + tracks_existing, allow_duplicates=False) + await library.extend(collection_merge_items + tracks_existing, allow_duplicates=False) assert len(library) == len(library_tracks_start) + len(collection_merge_items) - assert len(api_mock.request_history) == 0 # no requests made + api_mock.assert_not_called() # no requests made - library.extend(tracks_existing, allow_duplicates=True) + await library.extend(tracks_existing, allow_duplicates=True) assert len(library) == len(library_tracks_start) + len(collection_merge_items) + len(tracks_existing) - assert len(api_mock.request_history) == 0 # no requests made + api_mock.assert_not_called() # no requests made library._tracks = copy(library_tracks_start) local_tracks = random_tracks(len(collection_merge_items)) @@ -116,9 +117,9 @@ def test_extend(library: RemoteLibrary, collection_merge_items: list[RemoteTrack assert remote not in library assert local not in library - library.extend(local_tracks) + await library.extend(local_tracks) assert len(library) == len(library_tracks_start) + len(local_tracks) - assert len(api_mock.request_history) > 0 # new requests were made + 
api_mock.assert_called() # new requests were made @staticmethod def test_backup(library: RemoteLibrary): @@ -126,7 +127,7 @@ def test_backup(library: RemoteLibrary): assert library.backup_playlists() == expected @staticmethod - def assert_restore(library: RemoteLibrary, backup: Any): + async def assert_restore(library: RemoteLibrary, backup: Any): """Run test and assertions on restore_playlists functionality for given input backup data type""" backup_check: Mapping[str, list[str]] if isinstance(backup, RemoteLibrary): # get URIs from playlists in library @@ -134,7 +135,15 @@ def assert_restore(library: RemoteLibrary, backup: Any): elif isinstance(backup, Mapping) and all(isinstance(v, MusifyItem) for vals in backup.values() for v in vals): # get URIs from playlists in map values backup_check = {name: [track.uri for track in pl] for name, pl in backup.items()} - elif not isinstance(backup, Mapping) and isinstance(backup, Collection): + elif isinstance(backup, Mapping): + # get URIs from playlists in collection + backup_check = { + name: + [t["uri"] if isinstance(t, Mapping) else t for t in tracks["tracks"]] + if isinstance(tracks, Mapping) else tracks + for name, tracks in backup.items() + } + elif isinstance(backup, Collection): # get URIs from playlists in collection backup_check = {pl.name: [track.uri for track in pl] for pl in backup} else: @@ -144,7 +153,7 @@ def assert_restore(library: RemoteLibrary, backup: Any): name_new = next(name for name in backup_check if name not in library.playlists) library_test = deepcopy(library) - library_test.restore_playlists(playlists=backup, dry_run=False) + await library_test.restore_playlists(playlists=backup, dry_run=False) assert len(library_test.playlists[name_actual]) == len(backup_check[name_actual]) assert len(library_test.playlists[name_actual]) != len(library.playlists[name_actual]) @@ -154,7 +163,7 @@ def assert_restore(library: RemoteLibrary, backup: Any): assert library.api.handler.get(pl_new.url) # new playlist 
was created and is callable @pytest.mark.slow - def test_restore(self, library: RemoteLibrary, collection_merge_items: list[RemoteTrack]): + async def test_restore(self, library: RemoteLibrary, collection_merge_items: list[RemoteTrack]): name_actual, pl_actual = choice([(name, pl) for name, pl in library.playlists.items() if len(pl) > 10]) name_new = "new playlist" @@ -167,7 +176,7 @@ def test_restore(self, library: RemoteLibrary, collection_merge_items: list[Remo new_uri_list = [track.uri for track in collection_merge_items] backup_uri = {name_new: new_uri_list, "random new name": new_uri_list} library_test = deepcopy(library) - library_test.restore_playlists(playlists=backup_uri, dry_run=True) + await library_test.restore_playlists(playlists=backup_uri, dry_run=True) assert len(library_test.playlists) == len(library.playlists) # no new playlists created/added # Mapping[str, Iterable[str]] @@ -175,27 +184,36 @@ def test_restore(self, library: RemoteLibrary, collection_merge_items: list[Remo name_actual: [track.uri for track in pl_actual[:5]] + new_uri_list, name_new: new_uri_list, } - self.assert_restore(library=library, backup=backup_uri) + await self.assert_restore(library=library, backup=backup_uri) # Mapping[str, Iterable[Track]] backup_tracks = { name_actual: pl_actual[:5] + collection_merge_items, name_new: collection_merge_items, } - self.assert_restore(library=library, backup=backup_tracks) + await self.assert_restore(library=library, backup=backup_tracks) + + # Mapping[str, Mapping[str, Iterable[Mapping[str, Any]]]] + backup_nested = { + name_actual: { + "tracks": [{"uri": track.uri} for track in pl_actual[:5]] + [{"uri": uri} for uri in new_uri_list] + }, + name_new: {"tracks": [{"uri": uri} for uri in new_uri_list]}, + } + await self.assert_restore(library=library, backup=backup_nested) # Library backup_library = deepcopy(library) backup_library._playlists = {name_actual: backup_library.playlists[name_actual]} - 
backup_library.restore_playlists(playlists=backup_uri, dry_run=False) - self.assert_restore(library=library, backup=backup_library) + await backup_library.restore_playlists(playlists=backup_uri, dry_run=False) + await self.assert_restore(library=library, backup=backup_library) # Collection[Playlist] backup_pl = [backup_library.playlists[name_actual], backup_library.playlists[name_new]] - self.assert_restore(library=library, backup=backup_pl) + await self.assert_restore(library=library, backup=backup_pl) @staticmethod - def assert_sync(library: RemoteLibrary, playlists: Any, api_mock: RemoteMock): + async def assert_sync(library: RemoteLibrary, playlists: Any, api_mock: RemoteMock): """Run test and assertions on library sync functionality for given input playlists data type""" playlists_check: Mapping[str, Collection[MusifyItem]] if isinstance(playlists, RemoteLibrary): # get map of playlists from the given library @@ -210,7 +228,7 @@ def assert_sync(library: RemoteLibrary, playlists: Any, api_mock: RemoteMock): name_new = next(name for name in playlists_check if name not in library.playlists) library_test = deepcopy(library) - results = library_test.sync(playlists=playlists, dry_run=False) + results = await library_test.sync(playlists=playlists, dry_run=False) # existing playlist assertions assert results[name_actual].added == len(playlists_check[name_new]) @@ -218,7 +236,7 @@ def assert_sync(library: RemoteLibrary, playlists: Any, api_mock: RemoteMock): assert results[name_actual].unchanged == len(library.playlists[name_actual]) url = library_test.playlists[name_actual].url - requests = [req for req in api_mock.get_requests(method="POST") if req.url.startswith(url)] + requests = await api_mock.get_requests(method="POST", url=re.compile(url)) assert len(requests) > 0 # new playlist assertions @@ -229,20 +247,20 @@ def assert_sync(library: RemoteLibrary, playlists: Any, api_mock: RemoteMock): assert library.api.handler.get(library_test.playlists[name_new].url) # new 
playlist was created and is callable url = library_test.playlists[name_new].url - requests = [req for req in api_mock.get_requests(method="POST") if req.url.startswith(url)] + requests = await api_mock.get_requests(method="POST", url=re.compile(url)) assert len(requests) > 0 @staticmethod - def test_sync_dry_run(library: RemoteLibrary, api_mock: RemoteMock): + async def test_sync_dry_run(library: RemoteLibrary, api_mock: RemoteMock): new_playlists = copy(list(library.playlists.values())) for i, pl in enumerate(new_playlists, 1): pl.name = f"this is a new playlist name {i}" - library.sync(list(library.playlists.values()) + new_playlists, reload=True) - assert not api_mock.get_requests(method="POST") + await library.sync(list(library.playlists.values()) + new_playlists, reload=True) + assert not await api_mock.get_requests(method="POST") @pytest.mark.slow - def test_sync(self, library: RemoteLibrary, collection_merge_items: list[RemoteTrack], api_mock: RemoteMock): + async def test_sync(self, library: RemoteLibrary, collection_merge_items: list[RemoteTrack], api_mock: RemoteMock): name_actual, pl_actual = choice([(name, pl) for name, pl in library.playlists.items() if len(pl) > 10]) name_new = "new playlist" @@ -256,16 +274,16 @@ def test_sync(self, library: RemoteLibrary, collection_merge_items: list[RemoteT name_actual: pl_actual[:5] + collection_merge_items, name_new: collection_merge_items, } - self.assert_sync(library=library, playlists=playlists_tracks, api_mock=api_mock) + await self.assert_sync(library=library, playlists=playlists_tracks, api_mock=api_mock) # Library playlists_library = deepcopy(library) - playlists_library.restore_playlists(playlists=playlists_tracks, dry_run=False) + await playlists_library.restore_playlists(playlists=playlists_tracks, dry_run=False) for name in list(playlists_library.playlists.keys()): if name not in playlists_tracks: playlists_library.playlists.pop(name) - self.assert_sync(library=library, playlists=playlists_library, 
api_mock=api_mock) + await self.assert_sync(library=library, playlists=playlists_library, api_mock=api_mock) # Collection[Playlist] playlists_coll = [playlists_library.playlists[name_actual], playlists_library.playlists[name_new]] - self.assert_sync(library=library, playlists=playlists_coll, api_mock=api_mock) + await self.assert_sync(library=library, playlists=playlists_coll, api_mock=api_mock) diff --git a/tests/libraries/remote/core/object.py b/tests/libraries/remote/core/object.py index 4db449f3..f52561f9 100644 --- a/tests/libraries/remote/core/object.py +++ b/tests/libraries/remote/core/object.py @@ -3,6 +3,7 @@ from typing import Any import pytest +from aioresponses.core import RequestCall from musify.exception import MusifyKeyError from musify.libraries.local.track import LocalTrack @@ -61,10 +62,13 @@ def test_collection_getitem_dunder_method( class RemotePlaylistTester(RemoteCollectionTester, PlaylistTester, metaclass=ABCMeta): + @staticmethod + def _get_payload_from_request(request: RequestCall) -> dict[str, Any] | None: + return request.kwargs.get("body", request.kwargs.get("json")) + ########################################################################### ## Sync tests ########################################################################### - @abstractmethod def sync_playlist(self, response_valid: dict[str, Any], api: RemoteAPI) -> RemotePlaylist: """ @@ -83,75 +87,77 @@ def sync_items(self, *args, **kwargs) -> list[RemoteTrack]: @staticmethod @abstractmethod - def get_sync_uris(url: str, api_mock: RemoteMock) -> tuple[list[str], list[str]]: + async def get_sync_uris(url: str, api_mock: RemoteMock) -> tuple[list[str], list[str]]: """Return tuple of lists of URIs added and URIs cleared when applying sync operations""" raise NotImplementedError @staticmethod - def assert_playlist_loaded(sync_playlist: RemotePlaylist, api_mock: RemoteMock, count: int = 1) -> None: + async def assert_playlist_loaded(sync_playlist: RemotePlaylist, api_mock: 
RemoteMock, count: int = 1) -> None: """Assert the given playlist was fully reloaded through GET requests ``count`` number of times""" pages = api_mock.calculate_pages_from_response(sync_playlist.response) - requests = api_mock.get_requests(url=sync_playlist.url, method="GET") - requests += api_mock.get_requests(url=sync_playlist.url + "/tracks", method="GET") + requests = await api_mock.get_requests(method="GET", url=sync_playlist.url) + requests += await api_mock.get_requests(method="GET", url=sync_playlist.url + "/tracks") assert len(requests) == pages * count @staticmethod - def test_sync_dry_run(sync_playlist: RemotePlaylist, sync_items: list[RemoteTrack], api_mock: RemoteMock): - result_refresh_no_items = sync_playlist.sync(kind="refresh", reload=False) + async def test_sync_dry_run(sync_playlist: RemotePlaylist, sync_items: list[RemoteTrack], api_mock: RemoteMock): + result_refresh_no_items = await sync_playlist.sync(kind="refresh", reload=False) assert result_refresh_no_items.start == len(sync_playlist) assert result_refresh_no_items.added == result_refresh_no_items.start assert result_refresh_no_items.removed == result_refresh_no_items.start assert result_refresh_no_items.unchanged == 0 assert result_refresh_no_items.difference == 0 assert result_refresh_no_items.final == result_refresh_no_items.start - assert len(api_mock.request_history) == 0 + api_mock.assert_not_called() sync_items_extended = sync_items + sync_playlist[:10] - result_refresh_with_items = sync_playlist.sync(items=sync_items_extended, kind="refresh", reload=True) + result_refresh_with_items = await sync_playlist.sync(items=sync_items_extended, kind="refresh", reload=True) assert result_refresh_with_items.start == len(sync_playlist) assert result_refresh_with_items.added == len(sync_items_extended) assert result_refresh_with_items.removed == result_refresh_with_items.start assert result_refresh_with_items.unchanged == 0 assert result_refresh_with_items.difference == 
result_refresh_with_items.added - result_refresh_with_items.start assert result_refresh_with_items.final == result_refresh_with_items.added - assert len(api_mock.request_history) == 0 # reload does not happen on dry_run + api_mock.assert_not_called() # reload does not happen on dry_run - result_new = sync_playlist.sync(items=sync_items_extended, kind="new", reload=False) + result_new = await sync_playlist.sync(items=sync_items_extended, kind="new", reload=False) assert result_new.start == len(sync_playlist) assert result_new.added == len(sync_items) assert result_new.removed == 0 assert result_new.unchanged == result_new.start assert result_new.difference == result_new.added assert result_new.final == result_new.start + result_new.difference - assert len(api_mock.request_history) == 0 + api_mock.assert_not_called() sync_uri = {track.uri for track in sync_items_extended} - result_sync = sync_playlist.sync(items=sync_items_extended, kind="sync", reload=False) + result_sync = await sync_playlist.sync(items=sync_items_extended, kind="sync", reload=False) assert result_sync.start == len(sync_playlist) assert result_sync.added == len(sync_items) assert result_sync.removed == len([track.uri for track in sync_playlist if track.uri not in sync_uri]) assert result_sync.unchanged == len([track.uri for track in sync_playlist if track.uri in sync_uri]) assert result_sync.difference == len(sync_items) - result_sync.removed assert result_sync.final == result_sync.start + result_sync.difference - assert len(api_mock.request_history) == 0 + api_mock.assert_not_called() - def test_sync_reload(self, sync_playlist: RemotePlaylist, sync_items: list[RemoteTrack], api_mock: RemoteMock): + async def test_sync_reload( + self, sync_playlist: RemotePlaylist, sync_items: list[RemoteTrack], api_mock: RemoteMock + ): start = len(sync_playlist) sync_playlist.tracks.clear() assert len(sync_playlist) == 0 - sync_playlist.sync(kind="sync", items=sync_items, reload=True, dry_run=False) + await 
sync_playlist.sync(kind="sync", items=sync_items, reload=True, dry_run=False) # playlist will reload from mock so, for this test, it will just get back its original items assert len(sync_playlist) == start # 1 for skip dupes on add to playlist, 1 for reload - self.assert_playlist_loaded(sync_playlist=sync_playlist, api_mock=api_mock, count=2) + await self.assert_playlist_loaded(sync_playlist=sync_playlist, api_mock=api_mock, count=2) - def test_sync_new(self, sync_playlist: RemotePlaylist, sync_items: list[RemoteTrack], api_mock: RemoteMock): + async def test_sync_new(self, sync_playlist: RemotePlaylist, sync_items: list[RemoteTrack], api_mock: RemoteMock): sync_items_extended = sync_items + sync_playlist.tracks[:5] - result = sync_playlist.sync(kind="new", items=sync_items_extended, reload=False, dry_run=False) + result = await sync_playlist.sync(kind="new", items=sync_items_extended, reload=False, dry_run=False) assert result.start == len(sync_playlist) assert result.added == len(sync_items) @@ -160,16 +166,18 @@ def test_sync_new(self, sync_playlist: RemotePlaylist, sync_items: list[RemoteTr assert result.difference == result.added assert result.final == result.start + result.difference - uri_add, uri_clear = self.get_sync_uris(url=sync_playlist.url, api_mock=api_mock) + uri_add, uri_clear = await self.get_sync_uris(url=sync_playlist.url, api_mock=api_mock) assert uri_add == [track.uri for track in sync_items] assert uri_clear == [] # 1 for skip dupes check on add to playlist - self.assert_playlist_loaded(sync_playlist=sync_playlist, api_mock=api_mock, count=1) + await self.assert_playlist_loaded(sync_playlist=sync_playlist, api_mock=api_mock, count=1) - def test_sync_refresh(self, sync_playlist: RemotePlaylist, sync_items: list[RemoteTrack], api_mock: RemoteMock): + async def test_sync_refresh( + self, sync_playlist: RemotePlaylist, sync_items: list[RemoteTrack], api_mock: RemoteMock + ): start = len(sync_playlist) - result = 
sync_playlist.sync(items=sync_items, kind="refresh", reload=True, dry_run=False) + result = await sync_playlist.sync(items=sync_items, kind="refresh", reload=True, dry_run=False) assert result.start == start assert result.added == len(sync_items) @@ -178,16 +186,16 @@ def test_sync_refresh(self, sync_playlist: RemotePlaylist, sync_items: list[Remo # assert result.difference == 0 # useless when mocking + reload # assert result.final == start # useless when mocking + reload - uri_add, uri_clear = self.get_sync_uris(url=sync_playlist.url, api_mock=api_mock) + uri_add, uri_clear = await self.get_sync_uris(url=sync_playlist.url, api_mock=api_mock) assert uri_add == [track.uri for track in sync_items] assert uri_clear == [track.uri for track in sync_playlist] # 1 load current tracks on remote when clearing, 1 for reload - self.assert_playlist_loaded(sync_playlist=sync_playlist, api_mock=api_mock, count=2) + await self.assert_playlist_loaded(sync_playlist=sync_playlist, api_mock=api_mock, count=2) - def test_sync(self, sync_playlist: RemotePlaylist, sync_items: list[RemoteTrack], api_mock: RemoteMock): + async def test_sync(self, sync_playlist: RemotePlaylist, sync_items: list[RemoteTrack], api_mock: RemoteMock): sync_items_extended = sync_items + sync_playlist[:10] - result = sync_playlist.sync(kind="sync", items=sync_items_extended, reload=False, dry_run=False) + result = await sync_playlist.sync(kind="sync", items=sync_items_extended, reload=False, dry_run=False) sync_uri = {track.uri for track in sync_items_extended} assert result.start == len(sync_playlist) @@ -197,9 +205,9 @@ def test_sync(self, sync_playlist: RemotePlaylist, sync_items: list[RemoteTrack] assert result.difference == len(sync_items) - result.removed assert result.final == result.start + result.difference - uri_add, uri_clear = self.get_sync_uris(url=sync_playlist.url, api_mock=api_mock) + uri_add, uri_clear = await self.get_sync_uris(url=sync_playlist.url, api_mock=api_mock) assert uri_add == 
[track.uri for track in sync_items] assert uri_clear == [track.uri for track in sync_playlist if track.uri not in sync_uri] # 1 load when clearing - self.assert_playlist_loaded(sync_playlist=sync_playlist, api_mock=api_mock, count=1) + await self.assert_playlist_loaded(sync_playlist=sync_playlist, api_mock=api_mock, count=1) diff --git a/tests/libraries/remote/core/processors/check.py b/tests/libraries/remote/core/processors/check.py index 07cf7e2b..12df4058 100644 --- a/tests/libraries/remote/core/processors/check.py +++ b/tests/libraries/remote/core/processors/check.py @@ -45,12 +45,12 @@ def collections(self, playlist_urls: list[str]) -> list[BasicCollection]: @staticmethod @pytest.fixture - def setup_playlist_collection( + async def setup_playlist_collection( checker: RemoteItemChecker, playlist_urls: list[str] ) -> tuple[RemotePlaylist, BasicCollection]: """Setups up checker, playlist, and collection for testing match_to_remote functionality""" url = choice(playlist_urls) - pl = checker.factory.playlist(checker.api.get_items(url, extend=True)[0]) + pl = checker.factory.playlist(next(iter(await checker.api.get_items(url, extend=True)))) assert len(pl) > 10 assert len({item.uri for item in pl}) == len(pl) # all unique tracks @@ -69,7 +69,7 @@ def token_file_path(self, path: str) -> str: return path @staticmethod - def test_make_temp_playlist(checker: RemoteItemChecker, api_mock: RemoteMock, token_file_path: str): + async def test_make_temp_playlist(checker: RemoteItemChecker, api_mock: RemoteMock, token_file_path: str): # force auth test to fail and reload from token checker.api.handler.authoriser.token = None checker.api.handler.authoriser.token_file_path = token_file_path @@ -79,22 +79,22 @@ def test_make_temp_playlist(checker: RemoteItemChecker, api_mock: RemoteMock, to item.uri = None # does nothing when no URIs to add - checker._create_playlist(collection=collection) + await checker._create_playlist(collection=collection) assert not 
checker._playlist_name_urls assert not checker._playlist_name_collection - assert not api_mock.request_history + api_mock.assert_not_called() for item in collection: item.uri = random_uri() - checker._create_playlist(collection=collection) + await checker._create_playlist(collection=collection) assert checker.api.handler.authoriser.token is not None assert collection.name in checker._playlist_name_urls assert checker._playlist_name_collection[collection.name] == collection - assert len(api_mock.request_history) >= 2 + assert api_mock.total_requests >= 2 @staticmethod - def test_delete_temp_playlists( + async def test_delete_temp_playlists( checker: RemoteItemChecker, collections: list[BasicCollection], playlist_urls: list[str], @@ -108,11 +108,11 @@ def test_delete_temp_playlists( checker._playlist_name_urls = {collection.name: url for collection, url in zip(collections, playlist_urls)} checker._playlist_name_collection = {collection.name: collection for collection in collections} - checker._delete_playlists() + await checker._delete_playlists() assert checker.api.handler.authoriser.token is not None assert not checker._playlist_name_urls assert not checker._playlist_name_collection - assert len(api_mock.get_requests(method="DELETE")) == min(len(playlist_urls), len(collections)) + assert len(await api_mock.get_requests(method="DELETE")) == min(len(playlist_urls), len(collections)) @staticmethod def test_finalise(checker: RemoteItemChecker): @@ -141,7 +141,7 @@ def test_finalise(checker: RemoteItemChecker): ## ``pause`` step ########################################################################### @staticmethod - def test_pause_1( + async def test_pause_1( checker: RemoteItemChecker, setup_playlist_collection: tuple[RemotePlaylist, BasicCollection], mocker: MockerFixture, @@ -151,7 +151,7 @@ def test_pause_1( pl, collection = setup_playlist_collection patch_input(["h", collection.name, pl.uri], mocker=mocker) - checker._pause(page=1, total=1) + await 
checker._pause(page=1, total=1) mocker.stopall() capfd.close() @@ -162,13 +162,13 @@ def test_pause_1( assert f"Showing tracks for playlist: {pl.name}" in stdout # entered pl_pages = api_mock.calculate_pages(limit=20, total=len(pl)) - assert len(api_mock.get_requests(url=re.compile(pl.url + ".*"), method="GET")) == pl_pages + 1 + assert len(await api_mock.get_requests(method="GET", url=re.compile(pl.url))) == pl_pages + 1 assert not checker._skip assert not checker._quit @staticmethod - def test_pause_2( + async def test_pause_2( checker: RemoteItemChecker, mocker: MockerFixture, api_mock: RemoteMock, @@ -176,7 +176,7 @@ def test_pause_2( ): patch_input([random_str(10, 20), "u", "s"], mocker=mocker) - checker._pause(page=1, total=1) + await checker._pause(page=1, total=1) mocker.stopall() capfd.close() @@ -186,13 +186,13 @@ def test_pause_2( assert "Showing items originally added to" not in stdout assert "Showing tracks for playlist" not in stdout - assert not api_mock.request_history + api_mock.assert_not_called() assert checker._skip assert not checker._quit @staticmethod - def test_pause_3( + async def test_pause_3( checker: RemoteItemChecker, mocker: MockerFixture, api_mock: RemoteMock, @@ -200,7 +200,7 @@ def test_pause_3( ): patch_input(["q"], mocker=mocker) - checker._pause(page=1, total=1) + await checker._pause(page=1, total=1) mocker.stopall() capfd.close() @@ -209,7 +209,7 @@ def test_pause_3( assert "Input not recognised" not in stdout assert "Showing items originally added to" not in stdout assert "Showing tracks for playlist" not in stdout - assert not api_mock.request_history + api_mock.assert_not_called() assert not checker._skip assert checker._quit @@ -398,7 +398,7 @@ def test_match_to_input_quit( ## ``match_to_remote`` step ########################################################################### @staticmethod - def test_match_to_remote_no_changes( + async def test_match_to_remote_no_changes( checker: RemoteItemChecker, 
setup_playlist_collection: tuple[RemotePlaylist, BasicCollection], api_mock: RemoteMock @@ -406,15 +406,15 @@ def test_match_to_remote_no_changes( pl, collection = setup_playlist_collection # test collection == remote playlist; nothing happens - checker._match_to_remote(collection.name) + await checker._match_to_remote(collection.name) assert not checker._switched assert not checker._remaining pl_pages = api_mock.calculate_pages_from_response(pl.response) - assert len(api_mock.get_requests(url=re.compile(pl.url + ".*"), method="GET")) == pl_pages + assert len(await api_mock.get_requests(method="GET", url=re.compile(pl.url))) == pl_pages @staticmethod - def test_match_to_remote_removed( + async def test_match_to_remote_removed( checker: RemoteItemChecker, setup_playlist_collection: tuple[RemotePlaylist, BasicCollection], ): pl, collection = setup_playlist_collection @@ -423,12 +423,12 @@ def test_match_to_remote_removed( extra_tracks = [track for track in random_tracks(10) if track.has_uri] collection.extend(extra_tracks) - checker._match_to_remote(collection.name) + await checker._match_to_remote(collection.name) assert not checker._switched assert checker._remaining == extra_tracks @staticmethod - def test_match_to_remote_added( + async def test_match_to_remote_added( checker: RemoteItemChecker, setup_playlist_collection: tuple[RemotePlaylist, BasicCollection], ): pl, collection = setup_playlist_collection @@ -438,12 +438,12 @@ def test_match_to_remote_added( for item in collection[:5]: collection.remove(item) - checker._match_to_remote(collection.name) + await checker._match_to_remote(collection.name) assert not checker._switched assert not checker._remaining @staticmethod - def test_match_to_remote_switched( + async def test_match_to_remote_switched( checker: RemoteItemChecker, setup_playlist_collection: tuple[RemotePlaylist, BasicCollection], ): pl, collection = setup_playlist_collection @@ -454,13 +454,13 @@ def test_match_to_remote_switched( collection[i] |= 
item collection[i].uri = random_uri(kind=RemoteObjectType.TRACK) - checker._match_to_remote(collection.name) + await checker._match_to_remote(collection.name) assert checker._switched == collection[:5] assert not checker._remaining @staticmethod @pytest.mark.slow - def test_match_to_remote_complex( + async def test_match_to_remote_complex( checker: RemoteItemChecker, setup_playlist_collection: tuple[RemotePlaylist, BasicCollection], ): pl, collection = setup_playlist_collection @@ -480,7 +480,7 @@ def test_match_to_remote_complex( for item in collection[5:8]: collection.remove(item) - checker._match_to_remote(collection.name) + await checker._match_to_remote(collection.name) assert checker._switched == collection[:5] assert checker._remaining == 2 * extra_tracks @@ -489,7 +489,7 @@ def test_match_to_remote_complex( ########################################################################### @staticmethod @pytest.mark.slow - def test_check_uri( + async def test_check_uri( checker: RemoteItemChecker, setup_playlist_collection: tuple[RemotePlaylist, BasicCollection], remaining: list[LocalTrack], @@ -514,7 +514,7 @@ def test_check_uri( checker._skip = False checker._playlist_name_collection["do not run"] = collection - checker._check_uri() + await checker._check_uri() mocker.stopall() capfd.close() @@ -535,7 +535,7 @@ def test_check_uri( # called 2x: 1 initial, 1 after user inputs 'r' pl_pages = api_mock.calculate_pages_from_response(pl.response) - assert len(api_mock.get_requests(url=re.compile(pl.url + ".*"), method="GET")) == 2 * pl_pages + assert len(await api_mock.get_requests(method="GET", url=re.compile(pl.url))) == 2 * pl_pages assert checker._final_switched == collection[:5] + remaining[:len(uri_list)] assert checker._final_unavailable == remaining[len(uri_list):len(uri_list) + 3] @@ -546,15 +546,20 @@ def test_check_uri( ########################################################################### @staticmethod @pytest.mark.slow - def test_check( + async def 
test_check( checker: RemoteItemChecker, collections: list[BasicCollection], playlist_urls: list[str], mocker: MockerFixture, api_mock: RemoteMock, ): - def add_collection(self, collection: BasicCollection): + count = 0 + + async def add_collection(self, collection: BasicCollection): """Just simply add the collection and associated URL to the ItemChecker without calling API""" + nonlocal count + count += 1 + self._playlist_name_urls[collection.name] = playlist_name_urls[collection.name] self._playlist_name_collection[collection.name] = collection @@ -563,14 +568,19 @@ def add_collection(self, collection: BasicCollection): interval = len(collections) // 3 checker.interval = interval - batch = next(batched(collections, interval)) + batched_collections = batched(collections, interval) - # initially skip at pause, then mark all items in all processed collections in the first batch as unavailable - patch_input(["s", *["ua" for _ in batch]], mocker=mocker) + # mark all items in 1st and 2nd batch as unavailable, skip after the 2nd batch and quit + batch_1 = next(batched_collections) + batch_2 = next(batched_collections) + values = ["", *["ua" for _ in batch_1], "s", *["ua" for _ in batch_2]] + patch_input(values, mocker=mocker) - result = checker.check(collections) + result = await checker.check(collections) mocker.stopall() + assert count == len(batch_1) + len(batch_2) # only 2 batches executed + # resets after each run assert not checker._remaining assert not checker._switched @@ -579,17 +589,18 @@ def add_collection(self, collection: BasicCollection): assert not checker._final_skipped assert not result.switched - assert len(result.unavailable) == sum(len(collection) for collection in batch) + assert len(result.unavailable) == sum(len(collection) for collection in batch_1 + batch_2) assert not result.skipped - # all items in all collections in the first batch were marked as unavailable - for collection in batch: - for item in collection: + # all items in all collections 
in the 1st and 2nd batches were marked as unavailable + for coll in batch_1 + batch_2: + for item in coll: assert item.uri is None assert not item.has_uri - # deleted only the playlists in the first batch + # deleted only the playlists in the first 2 batches requests = [] - for url in (playlist_name_urls[collection.name] for collection in batch): - requests.append(api_mock.get_requests(url=url, method="DELETE")) - assert len(requests) == len(batch) + for url in (playlist_name_urls[collection.name] for collection in batch_1 + batch_2): + requests += await api_mock.get_requests(method="DELETE", url=re.compile(url)) + + assert len(requests) == len(batch_1) + len(batch_2) diff --git a/tests/libraries/remote/core/processors/search.py b/tests/libraries/remote/core/processors/search.py index 935c1b6d..08212557 100644 --- a/tests/libraries/remote/core/processors/search.py +++ b/tests/libraries/remote/core/processors/search.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod -from collections.abc import Callable, Iterable +from collections.abc import Iterable, Callable, Awaitable from copy import copy -from urllib.parse import parse_qs +from urllib.parse import unquote import pytest @@ -68,7 +68,7 @@ def unmatchable_items(self) -> list[LocalTrack]: return unmatchable_items @staticmethod - def test_get_results(searcher: RemoteItemSearcher, api_mock: RemoteMock): + async def test_get_results(searcher: RemoteItemSearcher, api_mock: RemoteMock): settings = SearchConfig( search_fields_1=[Tag.NAME, Tag.ARTIST], # query mock always returns match on name search_fields_2=[Tag.NAME, Tag.ALBUM], @@ -77,15 +77,16 @@ def test_get_results(searcher: RemoteItemSearcher, api_mock: RemoteMock): result_count=7 ) item = random_track() - results = searcher._get_results(item=item, kind=RemoteObjectType.TRACK, settings=settings) - requests = api_mock.get_requests(method="GET") + results = await searcher._get_results(item=item, kind=RemoteObjectType.TRACK, settings=settings) + requests = 
await api_mock.get_requests(method="GET") assert len(results) == settings.result_count assert len(requests) == 1 expected = [str(item.clean_tags.get(key)) for key in settings.search_fields_1] found = False - for k, v in parse_qs(requests[0].query).items(): - if expected == v[0].split(): + url, _, _ = next(iter(requests)) + for k, v in url.query.items(): + if expected == unquote(v).split(): found = True break @@ -95,17 +96,18 @@ def test_get_results(searcher: RemoteItemSearcher, api_mock: RemoteMock): # make these tags too long to query forcing them to return on results item.artist = 'b' * 200 item.album = 'c' * 200 - api_mock.reset_mock() # reset for new requests checks to work correctly + api_mock.reset() # reset for new requests checks to work correctly - results = searcher._get_results(item=item, kind=RemoteObjectType.TRACK, settings=settings) - requests = api_mock.get_requests(method="GET") + results = await searcher._get_results(item=item, kind=RemoteObjectType.TRACK, settings=settings) + requests = await api_mock.get_requests(method="GET") assert len(results) == settings.result_count assert len(requests) == 1 expected = [str(item.clean_tags.get(key)) for key in settings.search_fields_3] found = False - for k, v in parse_qs(requests[0].query).items(): - if expected == v[0].split(): + url, _, _ = next(iter(requests)) + for k, v in url.query.items(): + if expected == unquote(v).split(): found = True break @@ -116,8 +118,11 @@ def test_get_results(searcher: RemoteItemSearcher, api_mock: RemoteMock): ## _search_ tests ########################################################################### @staticmethod - def assert_search( - search_function: Callable[[Iterable[MusifyItemSettable] | MusifyCollection[MusifyItemSettable]], None], + async def assert_search( + search_function: Callable[ + [Iterable[MusifyItemSettable] | MusifyCollection[MusifyItemSettable]], + Awaitable[None] + ], collection: Iterable[MusifyItemSettable], search_items: Iterable[MusifyItemSettable], 
unmatchable_items: Iterable[MusifyItemSettable], @@ -127,7 +132,7 @@ def assert_search( assert item.has_uri is None assert item.uri is None - search_function(collection) + await search_function(collection) for item in search_items: assert item.has_uri assert item.uri is not None @@ -141,7 +146,7 @@ def assert_search( for item in search_items: item.uri = uri - search_function(collection) + await search_function(collection) for item in search_items: assert item.has_uri assert item.uri == uri @@ -149,27 +154,27 @@ def assert_search( assert item.has_uri is None assert item.uri is None - def test_search_items( + async def test_search_items( self, searcher: RemoteItemSearcher, search_items: list[MusifyItemSettable], unmatchable_items: list[LocalTrack] ): - self.assert_search( + await self.assert_search( searcher._search_items, collection=BasicCollection(name="test", items=search_items + unmatchable_items), search_items=search_items, unmatchable_items=unmatchable_items ) - def test_search_album( + async def test_search_album( self, searcher: RemoteItemSearcher, search_albums: list[Album], unmatchable_items: list[LocalTrack] ): collection = search_albums[0] search_items = copy(collection.tracks) collection.tracks.extend(unmatchable_items) - self.assert_search( + await self.assert_search( searcher._search_collection_unit, collection=collection, search_items=search_items, @@ -196,36 +201,36 @@ def search_album(search_albums: list[LocalAlbum]): return collection @staticmethod - def test_search_result_items( + async def test_search_result_items( searcher: RemoteItemSearcher, search_items: list[LocalTrack], unmatchable_items: list[LocalTrack] ): collection = BasicCollection(name="test", items=search_items + unmatchable_items) - result = searcher._search_collection(collection) + result = await searcher._search_collection(collection) assert len(result.matched) + len(result.unmatched) + len(result.skipped) == len(collection) assert len(result.matched) == len(search_items) assert 
len(result.unmatched) == len(unmatchable_items) assert len(result.skipped) == 0 # skips all matched on 2nd run - result = searcher._search_collection(collection) + result = await searcher._search_collection(collection) assert len(result.matched) + len(result.unmatched) + len(result.skipped) == len(collection) assert len(result.matched) == 0 assert len(result.unmatched) == len(unmatchable_items) assert len(result.skipped) == len(search_items) @staticmethod - def test_search_result_album(searcher: RemoteItemSearcher, search_album: LocalAlbum): + async def test_search_result_album(searcher: RemoteItemSearcher, search_album: LocalAlbum): skip = len([item for item in search_album if item.has_uri is not None]) - result = searcher._search_collection(search_album) + result = await searcher._search_collection(search_album) assert len(result.matched) + len(result.unmatched) + len(result.skipped) == len(search_album) assert len(result.matched) == len(search_album) - skip assert len(result.unmatched) == 0 assert len(result.skipped) == skip # skips all matched on 2nd run - result = searcher._search_collection(search_album) + result = await searcher._search_collection(search_album) assert len(result.matched) + len(result.unmatched) + len(result.skipped) == len(search_album) assert len(result.matched) == 0 assert len(result.unmatched) == 0 @@ -233,7 +238,7 @@ def test_search_result_album(searcher: RemoteItemSearcher, search_album: LocalAl @staticmethod @pytest.mark.slow - def test_search_result_combined( + async def test_search_result_combined( searcher: RemoteItemSearcher, search_items: list[LocalTrack], search_album: LocalAlbum, @@ -244,7 +249,7 @@ def test_search_result_combined( search_album.items.extend(unmatchable_items) skip = len([item for item in search_album if item.has_uri is not None]) - result = searcher._search_collection(search_album) + result = await searcher._search_collection(search_album) assert len(result.matched) + len(result.unmatched) + len(result.skipped) 
== len(search_album) assert len(result.matched) == matchable - skip @@ -252,7 +257,7 @@ def test_search_result_combined( assert len(result.skipped) == skip # skips all matched on 2nd run - result = searcher._search_collection(search_album) + result = await searcher._search_collection(search_album) assert len(result.matched) + len(result.unmatched) + len(result.skipped) == len(search_album) assert len(result.matched) == 0 assert len(result.unmatched) == len(unmatchable_items) @@ -262,7 +267,7 @@ def test_search_result_combined( ## main search tests ########################################################################### @staticmethod - def test_search( + async def test_search( searcher: RemoteItemSearcher, search_items: list[LocalTrack], search_album: LocalAlbum, @@ -272,7 +277,7 @@ def test_search( search_collection = BasicCollection(name="test", items=search_items + unmatchable_items) skip_album = len([item for item in search_album if item.has_uri is not None]) - results = searcher([search_collection, search_album]) + results = await searcher.search([search_collection, search_album]) result = results[search_collection.name] assert len(result.matched) + len(result.unmatched) + len(result.skipped) == len(search_collection) @@ -287,7 +292,7 @@ def test_search( assert len(result.skipped) == skip_album # check nothing happens on matched collections - api_mock.reset_mock() # reset for new requests checks to work correctly + api_mock.reset() # reset for new requests checks to work correctly search_matched = BasicCollection(name="test", items=search_items) - assert len(searcher.search([search_matched, search_album])) == 0 - assert len(api_mock.request_history) == 0 + assert len(await searcher.search([search_matched, search_album])) == 0 + api_mock.assert_not_called() diff --git a/tests/libraries/remote/core/utils.py b/tests/libraries/remote/core/utils.py index 616c577c..dc799af2 100644 --- a/tests/libraries/remote/core/utils.py +++ 
b/tests/libraries/remote/core/utils.py @@ -1,12 +1,12 @@ import re -from abc import abstractmethod +from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import Any -from urllib.parse import parse_qs +from typing import Any, ContextManager -from requests_mock import Mocker -# noinspection PyProtectedMember,PyUnresolvedReferences -from requests_mock.request import _RequestObjectProxy +from aiohttp import ClientResponse +from aioresponses import aioresponses +from aioresponses.core import RequestCall +from yarl import URL from musify.libraries.remote.core.enum import RemoteIDType, RemoteObjectType @@ -14,7 +14,7 @@ ALL_ITEM_TYPES = RemoteObjectType.all() -class RemoteMock(Mocker): +class RemoteMock(aioresponses, ContextManager, ABC): """Generates responses and sets up Remote API requests mock""" range_start = 25 @@ -25,6 +25,17 @@ class RemoteMock(Mocker): limit_upper = 20 limit_max = 50 + requests: dict[tuple[str, URL], list[RequestCall]] + _responses: list[ClientResponse] + + def __enter__(self): + super().__enter__() + self.setup_mock() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + @property @abstractmethod def item_type_map(self) -> dict[RemoteObjectType, list[dict[str, Any]]]: @@ -37,6 +48,21 @@ def item_type_map_user(self) -> dict[RemoteObjectType, list[dict[str, Any]]]: """Map of :py:class:`RemoteObjectType` to the mocked user items mapped as ``{: }``""" raise NotImplementedError + @abstractmethod + def setup_mock(self): + """Driver to set up mock responses for all endpoints""" + raise NotImplementedError + + @property + def total_requests(self) -> int: + """Returns the total number of requests made to this mock.""" + return sum(len(reqs) for reqs in self.requests.values()) + + def reset(self) -> None: + """Reset the log for the history of requests and responses for by this mock. 
Does not reset matches.""" + self.requests.clear() + self._responses.clear() + @staticmethod def calculate_pages(limit: int, total: int) -> int: """ @@ -55,65 +81,83 @@ def calculate_pages_from_response(self, response: Mapping[str, Any]) -> int: """ raise NotImplementedError - def get_requests( + async def get_requests( self, - url: str | re.Pattern[str] | None = None, method: str | None = None, + url: str | URL | re.Pattern[str] | None = None, # matches given after params have been stripped params: dict[str, Any] | None = None, response: dict[str, Any] | None = None - ) -> list[_RequestObjectProxy]: + ) -> list[tuple[URL, RequestCall, ClientResponse | None]]: """Get a get request from the history from the given URL and params""" - requests = [] - for request in self.request_history: - matches = [ - self._get_match_from_url(request=request, url=url), - self._get_match_from_method(request=request, method=method), - self._get_match_from_params(request=request, params=params), - self._get_match_from_response(request=request, response=response), - ] - if all(matches): - requests.append(request) - - return requests + results: list[tuple[URL, RequestCall]] = [] + for (request_method, request_url), requests in self.requests.items(): + for request in requests: + matches = [ + self._get_match_from_method(actual=request_method, expected=method), + self._get_match_from_url(actual=request_url, expected=url), + self._get_match_from_params(actual=request, expected=params), + await self._get_match_from_expected_response(actual=request_url, expected=response), + ] + if all(matches): + results.append((request_url, request)) + + return [(url, request, self._get_response_from_url(url=url)) for url, request in results] @staticmethod - def _get_match_from_url(request: _RequestObjectProxy, url: str | re.Pattern[str] | None = None) -> bool: - match = url is None + def _get_match_from_method(actual: str, expected: str | None = None) -> bool: + match = expected is None if not match: - if 
isinstance(url, str): - match = url.strip("/").endswith(request.path.strip("/")) - elif isinstance(url, re.Pattern): - match = bool(url.search(request.url)) + # noinspection PyProtectedMember + match = actual.upper() == expected.upper() return match @staticmethod - def _get_match_from_method(request: _RequestObjectProxy, method: str | None = None) -> bool: - match = method is None + def _get_match_from_url(actual: str | URL, expected: str | URL | re.Pattern[str] | None = None) -> bool: + match = expected is None if not match: - # noinspection PyProtectedMember - match = request._request.method.upper() == method.upper() + actual = str(actual).rstrip("/").split("?")[0] + if isinstance(expected, str): + match = actual == expected.split("?")[0] + elif isinstance(expected, URL): + match = actual == str(expected.with_query(None)) + elif isinstance(expected, re.Pattern): + match = bool(expected.search(actual)) return match @staticmethod - def _get_match_from_params(request: _RequestObjectProxy, params: dict[str, Any] | None = None) -> bool: - match = params is None - if not match and request.query: - for k, v in parse_qs(request.query).items(): - if k in params and str(params[k]) != v[0]: + def _get_match_from_params(actual: RequestCall, expected: dict[str, Any] | None = None) -> bool: + match = expected is None + if not match and (request_params := actual.kwargs.get("params")): + for k, v in request_params.items(): + if k in expected and str(expected[k]) != v: break match = True return match - @staticmethod - def _get_match_from_response(request: _RequestObjectProxy, response: dict[str, Any] | None = None) -> bool: - match = response is None - if not match and request.body: - for k, v in request.json().items(): - if k in response and str(response[k]) != str(v): + async def _get_match_from_expected_response( + self, actual: str | URL, expected: dict[str, Any] | None = None + ) -> bool: + match = expected is None + if not match: + response = 
self._get_response_from_url(url=actual) + if response is None: + return match + + payload = await response.json() + for k, v in payload.items(): + if k in expected and str(expected[k]) != str(v): break match = True return match + + def _get_response_from_url(self, url: str | URL) -> ClientResponse | None: + response = None + for response in self._responses: + if str(response.url) == url: + break + + return response diff --git a/tests/libraries/remote/spotify/api/mock.py b/tests/libraries/remote/spotify/api/mock.py index c995b889..0ea277c9 100644 --- a/tests/libraries/remote/spotify/api/mock.py +++ b/tests/libraries/remote/spotify/api/mock.py @@ -4,14 +4,12 @@ from datetime import datetime from random import choice, randrange, sample, random, shuffle from typing import Any -from urllib.parse import parse_qs +from urllib.parse import unquote from uuid import uuid4 +from aioresponses import CallbackResult from pycountry import countries, languages -# noinspection PyProtectedMember,PyUnresolvedReferences -from requests_mock.request import _RequestObjectProxy as Request -# noinspection PyProtectedMember,PyUnresolvedReferences -from requests_mock.response import _Context as Context +from yarl import URL from musify.libraries.remote.core.enum import RemoteObjectType as ObjectType from musify.libraries.remote.spotify.api import SpotifyAPI @@ -140,12 +138,6 @@ def get_duration(track: dict[str, Any]) -> int: self.user_audiobooks = [self.format_user_item(deepcopy(item), ObjectType.AUDIOBOOK) for item in self.audiobooks] self.setup_specific_conditions_user() - self.setup_requests_mock() - - track_ids = {track["id"] for track in self.tracks} - for album in self.albums: - for track in album["tracks"]["items"]: - assert track["id"] in track_ids ########################################################################### ## Setup - cross-reference @@ -235,12 +227,15 @@ def setup_valid_references(self): ########################################################################### ## 
Setup - requests ########################################################################### - def setup_requests_mock(self): - """Driver to setup requests_mock responses for all endpoints""" + def setup_mock(self): url_api = SpotifyDataWrangler.url_api self.setup_search_mock() - self.get(url=f"{url_api}/me", json=lambda _, __: deepcopy(self.user)) + self.get( + url=f"{url_api}/me", + callback=lambda *_, **__: CallbackResult(method="GET", payload=deepcopy(self.user)), + repeat=True + ) # setup responses as needed for each item type self.setup_items_mock(kind=ObjectType.TRACK, id_map={item["id"]: item for item in self.tracks}) @@ -273,14 +268,14 @@ def setup_requests_mock(self): items += generator(collection, count) url = f"{collection["href"]}/{key}" - self.setup_items_block_mock(url=url, items=items, total=collection[key]["total"]) + self.setup_items_block_mock(url_items=url, items=items, total=collection[key]["total"]) # artist's albums for i, artist in enumerate(self.artists): id_ = artist["id"] items = [album for album in self.artist_albums if any(art["id"] == id_ for art in album["artists"])] url = f"{url_api}/artists/{id_}/albums" - self.setup_items_block_mock(url=url, items=items) + self.setup_items_block_mock(url_items=url, items=items) # when getting a user's playlists, individual tracks are not returned user_playlists_reduced = deepcopy(self.user_playlists) @@ -288,27 +283,29 @@ def setup_requests_mock(self): playlist["tracks"] = {"href": playlist["tracks"]["href"], "total": playlist["tracks"]["total"]} # setup responses as needed for each 'user's saved' type - self.setup_items_block_mock(url=f"{url_api}/me/tracks", items=self.user_tracks) - self.setup_items_block_mock(url=f"{url_api}/me/following", items=self.user_artists) - self.setup_items_block_mock(url=f"{url_api}/me/episodes", items=self.user_episodes) + self.setup_items_block_mock(url_items=f"{url_api}/me/tracks", items=self.user_tracks) + 
self.setup_items_block_mock(url_items=f"{url_api}/me/following", items=self.user_artists) + self.setup_items_block_mock(url_items=f"{url_api}/me/episodes", items=self.user_episodes) - self.setup_items_block_mock(url=f"{url_api}/me/playlists", items=user_playlists_reduced) - self.setup_items_block_mock(url=f"{url_api}/users/{self.user_id}/playlists", items=user_playlists_reduced) - self.setup_items_block_mock(url=f"{url_api}/me/albums", items=self.user_albums) - self.setup_items_block_mock(url=f"{url_api}/me/shows", items=self.user_shows) - self.setup_items_block_mock(url=f"{url_api}/me/audiobooks", items=self.user_audiobooks) + self.setup_items_block_mock(url_items=f"{url_api}/me/playlists", items=user_playlists_reduced) + self.setup_items_block_mock(url_items=f"{url_api}/users/{self.user_id}/playlists", items=user_playlists_reduced) + self.setup_items_block_mock(url_items=f"{url_api}/me/albums", items=self.user_albums) + self.setup_items_block_mock(url_items=f"{url_api}/me/shows", items=self.user_shows) + self.setup_items_block_mock(url_items=f"{url_api}/me/audiobooks", items=self.user_audiobooks) self.setup_playlist_operations_mock() def setup_search_mock(self): """Setup requests mock for getting responses from the ``/search`` endpoint""" - def response_getter(req: Request, _: Context) -> dict[str, Any]: + def callback(url: URL, params: dict[str, Any], **_) -> CallbackResult: """Dynamically generate expected batched response from a request with an 'ids' param""" - req_params = parse_qs(req.query) - limit = int(req_params["limit"][0]) - offset = int(req_params.get("offset", [0])[0]) - query = req_params["q"][0] - kinds = req_params["type"][0].split(",") + params = params if params is not None else {} + params |= url.query + + limit = int(params["limit"]) + offset = int(params.get("offset", 0)) + query = unquote(params["q"]) + kinds = unquote(params["type"]).split(",") count = 0 total = 0 @@ -333,13 +330,14 @@ def response_getter(req: Request, _: Context) -> 
dict[str, Any]: shuffle(results[kind + "s"]) count += len(results[kind + "s"]) - return { - kind: self.format_items_block(url=url, items=items, offset=offset, limit=limit, total=total) + payload = { + kind: self.format_items_block(url=str(url), items=items, offset=offset, limit=limit, total=total) for kind, items in results.items() } + return CallbackResult(method="GET", payload=payload) - url = f"{SpotifyDataWrangler.url_api}/search" - self.get(url=re.compile(url + r"\?"), json=response_getter) + url_search = f"{SpotifyDataWrangler.url_api}/search" + self.get(url=re.compile(url_search + r"\?"), callback=callback, repeat=True) def setup_items_mock( self, kind: ObjectType | str, id_map: dict[str, dict[str, Any]], batchable: bool = True @@ -349,44 +347,47 @@ def setup_items_mock( Sets up mocks for /{``kind``}/{ID} endpoints for multi-calls and, when ``batchable`` is True, /{``kind``}?... for batchable-calls. """ - def response_getter(req: Request, _: Context) -> dict[str, Any]: + def callback(url: URL, params: dict[str, Any] = None, **_) -> CallbackResult: """Dynamically generate expected batched response from a request with an 'ids' param""" - req_params = parse_qs(req.query) - req_kind = req.path.split("/")[-1].replace("-", "_") + params = params if params is not None else {} + params |= url.query + req_kind = url.path.split("/")[-1].replace("-", "_") - id_list = req_params["ids"][0].split(",") - return {req_kind: [deepcopy(id_map[i]) for i in id_list]} + id_list = unquote(params["ids"]).split(",") + payload = {req_kind: [deepcopy(id_map[i]) for i in id_list]} + return CallbackResult(method="GET", payload=payload) url_api = SpotifyDataWrangler.url_api - url = f"{url_api}/{kind.name.lower()}s" if isinstance(kind, ObjectType) else f"{url_api}/{kind}" - if batchable: - self.get(url=re.compile(url + r"\?"), json=response_getter) # item batched calls + url_items = f"{url_api}/{kind.name.lower()}s" if isinstance(kind, ObjectType) else f"{url_api}/{kind}" - for id_, item 
in id_map.items(): - self.get(url=f"{url}/{id_}", json=deepcopy(item)) # item multi calls + if batchable: # item batched calls + self.get(url=re.compile(url_items + r"\?"), callback=callback, repeat=True) + for id_, item in id_map.items(): # item multi calls + self.get(url=re.compile(f"{url_items}/{id_}" + r"(\?|$)"), payload=deepcopy(item), repeat=True) - def setup_items_block_mock(self, url: str, items: list[dict[str, Any]], total: int | None = None) -> None: + def setup_items_block_mock(self, url_items: str, items: list[dict[str, Any]], total: int | None = None) -> None: """Setup requests mock for returning preset responses from the given ``items`` in an 'items block' format.""" - def response_getter(req: Request, _: Context) -> dict[str, Any]: + def callback(url: URL, params: dict[str, Any] = None, **__) -> CallbackResult: """Dynamically generate expected response for an items block from the given ``generator``""" - req_params = parse_qs(req.query) - limit = int(req_params["limit"][0]) - offset = int(req_params.get("offset", [0])[0]) + params = params if params is not None else {} + params |= url.query + limit = int(params["limit"]) + offset = int(params.get("offset", 0)) available = items available_total = total - if re.match(r".*/artists/\w+/albums$", url) and "include_groups" in req_params: + if re.match(r".*/artists/\w+/albums$", url.path) and "include_groups" in params: # special case for artist's albums - types = req_params["include_groups"][0].split(",") + types = unquote(params["include_groups"]).split(",") available = [i for i in items if i["album_type"] in types] available_total = len(available) it = deepcopy(available[offset: offset + limit]) items_block = self.format_items_block( - url=req.url, items=it, offset=offset, limit=limit, total=available_total + url=str(url), items=it, offset=offset, limit=limit, total=available_total ) - if url.endswith("me/following"): # special case for following artists + if url.path.endswith("me/following"): # special 
case for following artists items_block["cursors"] = {} if offset < available_total: items_block["cursors"]["after"] = it[-1]["id"] @@ -395,40 +396,43 @@ def response_getter(req: Request, _: Context) -> dict[str, Any]: items_block = {"artists": items_block} - return items_block + return CallbackResult(method="GET", payload=items_block) total = total or len(items) - self.get(url=re.compile(url + r"\?"), json=response_getter) + self.get(url=re.compile(url_items + r"\?"), callback=callback, repeat=True) def setup_playlist_operations_mock(self) -> None: """Generate playlist and setup ``requests_mock`` for playlist operations tests""" for playlist in self.user_playlists: - self.post(url=re.compile(playlist["href"] + "/tracks"), json={"snapshot_id": str(uuid4())}) - self.delete(url=re.compile(playlist["href"] + "/tracks"), json={"snapshot_id": str(uuid4())}) - self.delete(url=re.compile(playlist["href"]), json={"snapshot_id": str(uuid4())}) + self.post( + url=re.compile(playlist["href"] + "/tracks"), payload={"snapshot_id": str(uuid4())}, repeat=True + ) + self.delete( + url=re.compile(playlist["href"] + "/tracks"), payload={"snapshot_id": str(uuid4())}, repeat=True + ) + self.delete(url=re.compile(playlist["href"]), payload={"snapshot_id": str(uuid4())}, repeat=True) - def create_response_getter(req: Request, _: Context) -> dict[str, Any]: + def callback(_: str, json: dict[str, Any], **__) -> CallbackResult: """Process body and generate playlist response data""" - data = req.json() - playlist_ids = {pl["id"] for pl in self.playlists} - response = self.generate_playlist(owner=self.user, item_count=0) - while response["id"] in playlist_ids: - response = self.generate_playlist(owner=self.user, item_count=0) + payload = self.generate_playlist(owner=self.user, item_count=0) + while payload["id"] in playlist_ids: + payload = self.generate_playlist(owner=self.user, item_count=0) + + payload["name"] = json["name"] + payload["description"] = json["description"] + 
payload["public"] = json["public"] + payload["collaborative"] = json["collaborative"] + payload["owner"] = self.user_playlists[0]["owner"] - response["name"] = data["name"] - response["description"] = data["description"] - response["public"] = data["public"] - response["collaborative"] = data["collaborative"] - response["owner"] = self.user_playlists[0]["owner"] + self.get(url=re.compile(payload["href"]), payload=payload, repeat=True) + self.post(url=re.compile(payload["href"] + "/tracks"), payload={"snapshot_id": str(uuid4())}, repeat=True) + self.delete(url=re.compile(payload["href"] + "/tracks"), payload={"snapshot_id": str(uuid4())}, repeat=True) - self.get(url=response["href"], json=response) - self.post(url=re.compile(response["href"] + "/tracks"), json={"snapshot_id": str(uuid4())}) - self.delete(url=re.compile(response["href"] + "/tracks"), json={"snapshot_id": str(uuid4())}) - return response + return CallbackResult(method="POST", payload=payload) url = f"{SpotifyDataWrangler.url_api}/users/{self.user_id}/playlists" - self.post(url=url, json=create_response_getter) + self.post(url=url, callback=callback, repeat=True) ########################################################################### ## Formatters diff --git a/tests/libraries/remote/spotify/api/test_api.py b/tests/libraries/remote/spotify/api/test_api.py index f5ceb9ee..caf74c3e 100644 --- a/tests/libraries/remote/spotify/api/test_api.py +++ b/tests/libraries/remote/spotify/api/test_api.py @@ -1,8 +1,13 @@ +from random import choice + import pytest -from musify.api.cache.backend.base import ResponseCache, PaginatedRequestSettings +from musify.api.cache.backend.base import ResponseCache from musify.api.cache.backend.sqlite import SQLiteCache +from musify.api.cache.session import CachedSession +from musify.api.exception import APIError from musify.libraries.remote.spotify.api import SpotifyAPI +from tests.libraries.remote.spotify.api.mock import SpotifyMock from tests.libraries.remote.spotify.utils 
import random_id from tests.utils import random_str @@ -10,11 +15,12 @@ class TestSpotifyAPI: @pytest.fixture - def cache(self) -> ResponseCache: + async def cache(self) -> ResponseCache: """Yields a valid :py:class:`ResponseCache` to use throughout tests in this suite as a pytest.fixture.""" - return SQLiteCache.connect_with_in_memory_db() + async with SQLiteCache.connect_with_in_memory_db() as cache: + yield cache - def test_init_authoriser(self, cache: ResponseCache): + def test_init(self, cache: ResponseCache): client_id = "CLIENT_ID" client_secret = "CLIENT_SECRET" scopes = ["scope 1", "scope 2"] @@ -34,51 +40,72 @@ def test_init_authoriser(self, cache: ResponseCache): assert api.handler.authoriser.user_args["params"]["scope"] == " ".join(scopes) assert api.handler.authoriser.token_file_path == token_file_path - assert api.handler.cache.cache_name == cache.cache_name - - def test_init_cache(self, cache: ResponseCache): - SpotifyAPI(cache=cache) - - expected_names_normal = [ - "tracks", - "audio_features", - "audio_analysis", - "albums", - "artists", - "episodes", - "chapters", - ] - expected_names_paginated = ["album_tracks", "artist_albums", "show_episodes", "audiobook_chapters"] - - assert all(name in cache for name in expected_names_normal) - assert all(name in cache for name in expected_names_paginated) - - assert all(not isinstance(cache[name].settings, PaginatedRequestSettings) for name in expected_names_normal) - assert all(isinstance(cache[name].settings, PaginatedRequestSettings) for name in expected_names_paginated) - - def test_init_cache_repository_getter(self, cache: ResponseCache): - api = SpotifyAPI(cache=cache) - - name_url_map = { - "tracks": f"{api.wrangler.url_api}/tracks/{random_id()}", - "artists": f"{api.wrangler.url_api}/artists?ids={",".join(random_id() for _ in range(10))}", - "albums": f"{api.wrangler.url_api}/albums?ids={",".join(random_id() for _ in range(50))}", - } - names_paginated = ["artist_albums", "album_tracks"] - for name 
in names_paginated: - parent, child = name.split("_") - parent = parent.rstrip("s") + "s" - child = child.rstrip("s") + "s" - - url = f"{api.wrangler.url_api}/{parent}/{random_id()}/{child}" - name_url_map[name] = url - - for name, url in name_url_map.items(): - repository = cache.get_repository_from_url(url) - assert repository.settings.name == name - - # un-cached URLs - assert cache.get_repository_from_url(f"{api.wrangler.url_api}/me/albums") is None - assert cache.get_repository_from_url(f"{api.wrangler.url_api}/search") is None - assert cache.get_repository_from_url(f"{api.wrangler.url_api}/playlists/{random_id()}/followers") is None - assert cache.get_repository_from_url(f"{api.wrangler.url_api}/users/{random_str(10, 30)}/playlists") is None + async def test_context_management(self, cache: ResponseCache, api_mock: SpotifyMock): + api = SpotifyAPI( + cache=cache, + token={"access_token": "fake access token", "token_type": "Bearer", "scope": "test-read"}, + test_args=None, + test_expiry=0, + test_condition=None, + ) + + with pytest.raises(APIError): + assert api.user_id + + async with api as a: + assert a.user_id == api_mock.user_id + + assert isinstance(a.handler.session, CachedSession) + assert a.handler.session.cache.cache_name == cache.cache_name + + expected_names = [ + "tracks", + "audio_features", + "audio_analysis", + "albums", + "artists", + "episodes", + "chapters", + "album_tracks", + "artist_albums", + "show_episodes", + "audiobook_chapters" + ] + + assert all(name in cache for name in expected_names) + + repository = choice(list(a.handler.session.cache.values())) + await repository.count() # just check this doesn't fail + + # noinspection PyTestUnpassedFixture + async def test_cache_repository_getter(self, cache: ResponseCache, api_mock: SpotifyMock): + async with SpotifyAPI( + cache=cache, + token={"access_token": "fake access token", "token_type": "Bearer", "scope": "test-read"}, + test_args=None, + test_expiry=0, + test_condition=None, + ) as 
api: + name_url_map = { + "tracks": f"{api.wrangler.url_api}/tracks/{random_id()}", + "artists": f"{api.wrangler.url_api}/artists?ids={",".join(random_id() for _ in range(10))}", + "albums": f"{api.wrangler.url_api}/albums?ids={",".join(random_id() for _ in range(50))}", + } + names_paginated = ["artist_albums", "album_tracks"] + for name in names_paginated: + parent, child = name.split("_") + parent = parent.rstrip("s") + "s" + child = child.rstrip("s") + "s" + + url = f"{api.wrangler.url_api}/{parent}/{random_id()}/{child}" + name_url_map[name] = url + + for name, url in name_url_map.items(): + repository = cache.get_repository_from_url(url) + assert repository.settings.name == name + + # un-cached URLs + assert cache.get_repository_from_url(f"{api.wrangler.url_api}/me/albums") is None + assert cache.get_repository_from_url(f"{api.wrangler.url_api}/search") is None + assert cache.get_repository_from_url(f"{api.wrangler.url_api}/playlists/{random_id()}/followers") is None + assert cache.get_repository_from_url(f"{api.wrangler.url_api}/users/{random_str(10, 30)}/playlists") is None diff --git a/tests/libraries/remote/spotify/api/test_artist.py b/tests/libraries/remote/spotify/api/test_artist.py index 76e6cab1..1810ce10 100644 --- a/tests/libraries/remote/spotify/api/test_artist.py +++ b/tests/libraries/remote/spotify/api/test_artist.py @@ -71,7 +71,7 @@ def assert_artist_albums_enriched(albums: list[dict[str, Any]]) -> None: assert album["tracks"]["total"] == album["total_tracks"] assert album["id"] in album["tracks"]["href"] - def assert_artist_albums_results( + async def assert_artist_albums_results( self, results: dict[str, list[dict[str, Any]]], source: dict[str, dict[str, Any]], @@ -90,7 +90,7 @@ def assert_artist_albums_results( # appropriate number of requests made url = f"{api.url}/artists/{id_}/albums" - requests = api_mock.get_requests(url=url) + requests = await api_mock.get_requests(url=url) assert_calls(expected=expected[id_], requests=requests, 
limit=limit, api_mock=api_mock) if not update: @@ -102,7 +102,7 @@ def assert_artist_albums_results( assert reduced == expected[id_] self.assert_artist_albums_enriched(source[id_]["albums"]["items"]) - def test_get_artist_albums_single_string( + async def test_get_artist_albums_single_string( self, artist_albums: list[dict[str, Any]], artist: dict[str, Any], @@ -111,13 +111,13 @@ def test_get_artist_albums_single_string( api_mock: SpotifyMock, ): limit = get_limit(artist_albums, api_mock.limit_max) - results = api.get_artist_albums( + results = await api.get_artist_albums( values=random_id_type(id_=artist["id"], wrangler=api.wrangler, kind=RemoteObjectType.ARTIST), types=artist_album_types, limit=limit ) - self.assert_artist_albums_results( + await self.assert_artist_albums_results( results=results, source={artist["id"]: artist}, expected={artist["id"]: artist_albums}, @@ -127,7 +127,7 @@ def test_get_artist_albums_single_string( update=False ) - def test_get_artist_albums_single_mapping( + async def test_get_artist_albums_single_mapping( self, artist_albums: list[dict[str, Any]], artist: dict[str, Any], @@ -136,9 +136,9 @@ def test_get_artist_albums_single_mapping( api_mock: SpotifyMock, ): limit = get_limit(artist_albums, api_mock.limit_max) - results = api.get_artist_albums(values=artist, types=artist_album_types, limit=limit) + results = await api.get_artist_albums(values=artist, types=artist_album_types, limit=limit) - self.assert_artist_albums_results( + await self.assert_artist_albums_results( results=results, source={artist["id"]: artist}, expected={artist["id"]: artist_albums}, @@ -148,7 +148,7 @@ def test_get_artist_albums_single_mapping( update=True ) - def test_get_artist_albums_many_string( + async def test_get_artist_albums_many_string( self, artists_albums: dict[str, list[dict[str, Any]]], artists: dict[str, dict[str, Any]], @@ -157,13 +157,13 @@ def test_get_artist_albums_many_string( api_mock: SpotifyMock, ): limit = 50 - results = 
api.get_artist_albums( + results = await api.get_artist_albums( values=random_id_types(id_list=artists, wrangler=api.wrangler, kind=RemoteObjectType.ARTIST), types=artist_album_types, limit=limit, ) - self.assert_artist_albums_results( + await self.assert_artist_albums_results( results=results, source=artists, expected=artists_albums, @@ -173,7 +173,7 @@ def test_get_artist_albums_many_string( update=False ) - def test_get_artist_albums_many_mapping( + async def test_get_artist_albums_many_mapping( self, artists_albums: dict[str, list[dict[str, Any]]], artists: dict[str, dict[str, Any]], @@ -182,9 +182,9 @@ def test_get_artist_albums_many_mapping( api_mock: SpotifyMock, ): limit = 50 - results = api.get_artist_albums(values=artists.values(), types=artist_album_types, limit=limit) + results = await api.get_artist_albums(values=artists.values(), types=artist_album_types, limit=limit) - self.assert_artist_albums_results( + await self.assert_artist_albums_results( results=results, source=artists, expected=artists_albums, @@ -194,7 +194,7 @@ def test_get_artist_albums_many_mapping( update=True ) - def test_artist_albums_single_response( + async def test_artist_albums_single_response( self, artist_albums: list[dict[str, Any]], artist: dict[str, Any], @@ -205,10 +205,10 @@ def test_artist_albums_single_response( test = SpotifyArtist(artist, skip_checks=True) assert len(test.albums) < len(artist_albums) - api.get_artist_albums(values=test, types=artist_album_types) + await api.get_artist_albums(values=test, types=artist_album_types) assert len(test.albums) == len(artist_albums) - def test_artist_albums_many_response( + async def test_artist_albums_many_response( self, artists_albums: dict[str, list[dict[str, Any]]], artists: dict[str, dict[str, Any]], @@ -220,6 +220,6 @@ def test_artist_albums_many_response( for artist in test: assert len(artist.albums) < len(artists_albums[artist.id]) - api.get_artist_albums(values=test, types=artist_album_types) + await 
api.get_artist_albums(values=test, types=artist_album_types) for artist in test: assert len(artist.albums) == len(artists_albums[artist.id]) diff --git a/tests/libraries/remote/spotify/api/test_cache.py b/tests/libraries/remote/spotify/api/test_cache.py index 4a2ca8c0..100d7402 100644 --- a/tests/libraries/remote/spotify/api/test_cache.py +++ b/tests/libraries/remote/spotify/api/test_cache.py @@ -20,13 +20,13 @@ def test_core_getters(self, settings: SpotifyRequestSettings, api_mock: SpotifyM response = choice(responses) name = response["display_name"] if response["type"] == "user" else response["name"] assert settings.get_name(response) == name - assert settings.get_id(response["href"]) == response["id"] + assert settings.get_key(response["href"]) == (response["id"],) response = choice(api_mock.user_tracks) assert settings.get_name(response) is None url = f"{SpotifyDataWrangler.url_api}/me/tracks" - assert settings.get_id(url) is None + assert settings.get_key(url) == (None,) @pytest.fixture(scope="class") def settings_paginated(self) -> SpotifyPaginatedRequestSettings: diff --git a/tests/libraries/remote/spotify/api/test_item.py b/tests/libraries/remote/spotify/api/test_item.py index cb46cd21..10d92aaf 100644 --- a/tests/libraries/remote/spotify/api/test_item.py +++ b/tests/libraries/remote/spotify/api/test_item.py @@ -3,11 +3,12 @@ from itertools import batched from random import sample, randrange, choice from typing import Any -from urllib.parse import parse_qs +from urllib.parse import unquote import pytest from musify.api.cache.backend.sqlite import SQLiteCache +from musify.api.cache.session import CachedSession from musify.api.exception import APIError from musify.libraries.remote.core import RemoteResponse from musify.libraries.remote.core.enum import RemoteIDType, RemoteObjectType @@ -39,10 +40,16 @@ def responses(self, _responses: dict[str, dict[str, Any]], key: str) -> dict[str return {id_: response for id_, response in _responses.items() if key is None 
or response[key]["total"] > 3} @pytest.fixture(scope="class") - def api_cache(self, api: SpotifyAPI) -> SpotifyAPI: + async def api_cache(self, api: SpotifyAPI) -> SpotifyAPI: """Yield an authorised :py:class:`SpotifyAPI` object with a :py:class:`ResponseCache` configured.""" - cache = SQLiteCache.connect_with_in_memory_db() - return SpotifyAPI(cache=cache, token=api.handler.authoriser.token) + async with SpotifyAPI( + cache=SQLiteCache.connect_with_in_memory_db(), + token=api.handler.authoriser.token, + test_args=api.handler.authoriser.test_args, + test_expiry=api.handler.authoriser.test_expiry, + test_condition=api.handler.authoriser.test_condition, + ) as api: + yield api @staticmethod def reduce_items(response: dict[str, Any], key: str, api: SpotifyAPI, api_mock: SpotifyMock, pages: int = 3) -> int: @@ -136,7 +143,7 @@ def assert_get_items_results( assert len(test[result[self.id_key]][key]["items"]) == expect[key]["total"] self.assert_similar(expect, test[result[self.id_key]], key) - def assert_get_items_calls( + async def assert_get_items_calls( self, responses: Collection[dict[str, Any]], object_type: RemoteObjectType, @@ -147,12 +154,12 @@ def assert_get_items_calls( ): """Assert appropriate number of requests made for get_items method calls""" url = f"{api.url}/{object_type.name.lower()}s" - requests = api_mock.get_requests(url=url) + requests = await api_mock.get_requests(url=url) for response in responses: if limit is None or object_type in {RemoteObjectType.USER, RemoteObjectType.PLAYLIST}: - requests += api_mock.get_requests(url=f"{url}/{response[self.id_key]}") + requests += await api_mock.get_requests(url=f"{url}/{response[self.id_key]}") if key: - requests += api_mock.get_requests(url=f"{url}/{response[self.id_key]}/{key}") + requests += await api_mock.get_requests(url=f"{url}/{response[self.id_key]}/{key}") assert_calls(expected=responses, requests=requests, key=key, limit=limit, api_mock=api_mock) @staticmethod @@ -180,20 +187,19 @@ def 
test_get_unit(self, api: SpotifyAPI): assert api._get_unit(kind="Audio Features", key="tracks") == "audio features" assert api._get_unit(key="audio-features") == "audio features" - def test_get_items_batches_limited(self, api: SpotifyAPI, api_mock: SpotifyMock): + async def test_get_items_batches_limited(self, api: SpotifyAPI, api_mock: SpotifyMock): key = RemoteObjectType.TRACK.name.lower() + "s" url = f"{api.url}/{key}" id_list = [track[self.id_key] for track in api_mock.tracks] valid_limit = randrange(api_mock.limit_lower + 1, api_mock.limit_upper - 1) id_list_reduced = sample(id_list, k=api_mock.limit_lower) - api._get_items_batched(url=url, id_list=id_list_reduced, key=key, limit=api_mock.limit_upper - 50) - api._get_items_batched(url=url, id_list=id_list, key=key, limit=api_mock.limit_upper + 50) - api._get_items_batched(url=url, id_list=id_list, key=key, limit=valid_limit) + await api._get_items_batched(url=url, id_list=id_list_reduced, key=key, limit=api_mock.limit_upper - 50) + await api._get_items_batched(url=url, id_list=id_list, key=key, limit=api_mock.limit_upper + 50) + await api._get_items_batched(url=url, id_list=id_list, key=key, limit=valid_limit) - for request in api_mock.get_requests(url=url): - request_params = parse_qs(request.query) - count = len(request_params["ids"][0].split(",")) + for url, _, _ in await api_mock.get_requests(url=url): + count = len(unquote(url.query["ids"]).split(",")) assert count >= 1 assert count <= api_mock.limit_max @@ -201,7 +207,7 @@ def test_get_items_batches_limited(self, api: SpotifyAPI, api_mock: SpotifyMock) ## Input validation ########################################################################### @pytest.mark.parametrize("object_type", [RemoteObjectType.ALBUM], ids=idfn) - def test_extend_items_input_validation( + async def test_extend_items_input_validation( self, object_type: RemoteObjectType, response: dict[str, Any], @@ -213,41 +219,41 @@ def test_extend_items_input_validation( while 
len(response[key][api.items_key]) < response[key]["total"]: response[key][api.items_key].append(choice(response[key][api.items_key])) - api.extend_items(response, kind=object_type, key=api.collection_item_map[object_type]) - assert not api_mock.request_history + await api.extend_items(response, kind=object_type, key=api.collection_item_map[object_type]) + api_mock.assert_not_called() - def test_get_user_items_input_validation(self, api: SpotifyAPI): + async def test_get_user_items_input_validation(self, api: SpotifyAPI): # raises error when invalid item type given for kind in set(ALL_ITEM_TYPES) - api.user_item_types: with pytest.raises(RemoteObjectTypeError): - api.get_user_items(kind=kind) + await api.get_user_items(kind=kind) # may only get valid user item types that are not playlists from the currently authorised user for kind in api.user_item_types - {RemoteObjectType.PLAYLIST}: with pytest.raises(RemoteObjectTypeError): - api.get_user_items(user=random_str(1, RemoteIDType.ID.value - 1), kind=kind) + await api.get_user_items(user=random_str(1, RemoteIDType.ID.value - 1), kind=kind) - def test_extend_tracks_input_validation(self, api: SpotifyAPI): - assert api.extend_tracks(values=random_ids(), features=False, analysis=False) == [] - assert api.extend_tracks(values=[], features=True, analysis=True) == [] + async def test_extend_tracks_input_validation(self, api: SpotifyAPI): + assert await api.extend_tracks(values=random_ids(), features=False, analysis=False) == [] + assert await api.extend_tracks(values=[], features=True, analysis=True) == [] value = api.wrangler.convert( random_id(), kind=RemoteObjectType.ALBUM, type_in=RemoteIDType.ID, type_out=RemoteIDType.URL ) with pytest.raises(RemoteObjectTypeError): - api.extend_tracks(values=value, features=True) + await api.extend_tracks(values=value, features=True) - def test_get_artist_albums_input_validation(self, api: SpotifyAPI): - assert api.get_artist_albums(values=[]) == {} + async def 
test_get_artist_albums_input_validation(self, api: SpotifyAPI): + assert await api.get_artist_albums(values=[]) == {} value = api.wrangler.convert( random_id(), kind=RemoteObjectType.ALBUM, type_in=RemoteIDType.ID, type_out=RemoteIDType.URL ) with pytest.raises(RemoteObjectTypeError): - api.get_artist_albums(values=value) + await api.get_artist_albums(values=value) with pytest.raises(APIError): - api.get_artist_albums(values=random_id(), types=("unknown", "invalid")) + await api.get_artist_albums(values=random_id(), types=("unknown", "invalid")) ########################################################################### ## Cached-, Multi-, Batched-, and Extend tests for each supported item type @@ -258,7 +264,7 @@ def test_get_artist_albums_input_validation(self, api: SpotifyAPI): RemoteObjectType.PLAYLIST, RemoteObjectType.USER, ], ids=idfn) - def test_get_items_from_cache_skips( + async def test_get_items_from_cache_skips( self, object_type: RemoteObjectType, responses: dict[str, dict[str, Any]], @@ -270,11 +276,11 @@ def test_get_items_from_cache_skips( id_list = list(responses.keys()) # skip when no cache present - assert api.handler.cache is None - assert api._get_items_from_cache(method="GET", url=url, id_list=id_list) == ([], [], id_list) + assert not isinstance(api.handler.session, CachedSession) + assert await api._get_items_from_cache(method="GET", url=url, id_list=id_list) == ([], [], id_list) # skip when no repository found - assert api_cache._get_items_from_cache(method="GET", url=url, id_list=id_list) == ([], [], id_list) + assert await api_cache._get_items_from_cache(method="GET", url=url, id_list=id_list) == ([], [], id_list) # noinspection PyTestUnpassedFixture @pytest.mark.parametrize("object_type", [ @@ -286,7 +292,7 @@ def test_get_items_from_cache_skips( RemoteObjectType.AUDIOBOOK, RemoteObjectType.CHAPTER, ], ids=idfn) - def test_get_items_from_cache( + async def test_get_items_from_cache( self, object_type: RemoteObjectType, 
responses: dict[str, dict[str, Any]], @@ -299,17 +305,17 @@ def test_get_items_from_cache( method = "GET" responses_remapped = {(method, id_): response for id_, response in list(responses.items())[:limit]} - repository = api_cache.handler.cache.get_repository_from_url(url=url) - repository.update(responses_remapped) - assert all((method, id_) in repository for id_ in id_list[:limit]) + repository = api_cache.handler.session.cache.get_repository_from_url(url=url) + await repository.save_responses(responses_remapped) + assert all(repository.contains((method, id_)) for id_ in id_list[:limit]) - results, ids_found, ids_not_found = api_cache._get_items_from_cache(method=method, url=url, id_list=id_list) - assert len(results) == len(ids_found) == limit - assert len(ids_not_found) == len(responses) - limit + results, found, not_found = await api_cache._get_items_from_cache(method=method, url=url, id_list=id_list) + assert len(results) == len(found) == limit + assert len(not_found) == len(responses) - limit assert results == list(responses.values())[:limit] - assert not api_mock.request_history + api_mock.assert_not_called() - def test_get_items_multi( + async def test_get_items_multi( self, object_type: RemoteObjectType, responses: dict[str, dict[str, Any]], @@ -319,15 +325,15 @@ def test_get_items_multi( url = f"{api.url}/{object_type.name.lower()}s" params = {"key": "value"} - results = api._get_items_multi(url=url, id_list=responses, params=params, key=None) - requests = api_mock.get_requests(url=url) + results = await api._get_items_multi(url=url, id_list=responses, params=params, key=None) + requests = [url for url, _, _ in await api_mock.get_requests(url=url)] self.assert_item_types(results=results, key=object_type.name.lower()) self.assert_get_items_results(results=results, expected=responses, object_type=object_type) self.assert_params(requests=requests, params=params) # appropriate number of requests were made for multi requests - requests = [req for id_ in 
responses for req in api_mock.get_requests(url=f"{url}/{id_}")] + requests = [req for id_ in responses for req in await api_mock.get_requests(url=f"{url}/{id_}")] assert len(requests) == len(responses) @pytest.mark.parametrize("object_type", [ @@ -339,7 +345,7 @@ def test_get_items_multi( RemoteObjectType.AUDIOBOOK, RemoteObjectType.CHAPTER, ], ids=idfn) - def test_get_items_batched( + async def test_get_items_batched( self, object_type: RemoteObjectType, responses: dict[str, dict[str, Any]], @@ -351,8 +357,8 @@ def test_get_items_batched( params = {"key": "value"} limit = get_limit(responses, max_limit=api_mock.limit_max, pages=3) - results = api._get_items_batched(url=url, id_list=responses, params=params, key=key, limit=limit) - requests = api_mock.get_requests(url=url) + results = await api._get_items_batched(url=url, id_list=responses, params=params, key=key, limit=limit) + requests = [url for url, _, _ in await api_mock.get_requests(url=url)] self.assert_item_types(results=results, key=object_type.name.lower()) self.assert_get_items_results(results=results, expected=responses, object_type=object_type) @@ -360,14 +366,14 @@ def test_get_items_batched( # appropriate number of requests were made for batched requests id_params = [{"ids": ",".join(ids)} for ids in batched(responses, limit)] - requests = [req for req in requests if "ids" in parse_qs(req.query)] + requests = [url for url in requests if "ids" in url.query] assert len(requests) == len(id_params) < len(results) self.assert_params(requests=requests, params=id_params) @pytest.mark.parametrize("object_type", [ RemoteObjectType.PLAYLIST, RemoteObjectType.ALBUM, RemoteObjectType.SHOW, RemoteObjectType.AUDIOBOOK, ], ids=idfn) - def test_extend_items( + async def test_extend_items( self, object_type: RemoteObjectType, response: dict[str, Any], @@ -379,8 +385,7 @@ def test_extend_items( limit = self.reduce_items(response=response, key=key, api=api, api_mock=api_mock) test = response[key] - results = 
api.extend_items(response=test, key=api.collection_item_map.get(object_type, object_type)) - requests = api_mock.get_requests(url=test["href"].split("?")[0]) + results = await api.extend_items(response=test, key=api.collection_item_map.get(object_type, object_type)) # assert extension to total assert len(results) == total @@ -389,13 +394,14 @@ def test_extend_items( self.assert_item_types(results=test[api.items_key], object_type=object_type, key=key) # appropriate number of requests made (minus 1 for initial input) + requests = await api_mock.get_requests(url=test["href"].split("?")[0]) assert len(requests) == api_mock.calculate_pages(limit=limit, total=total) - 1 @pytest.mark.parametrize("object_type", [ RemoteObjectType.PLAYLIST, RemoteObjectType.ALBUM, # RemoteObjectType.SHOW, RemoteObjectType.AUDIOBOOK, RemoteResponse types not yet implemented for these ], ids=idfn) - def test_extend_items( + async def test_extend_items( self, object_type: RemoteObjectType, response: dict[str, Any], @@ -408,7 +414,7 @@ def test_extend_items( original = object_factory[object_type](deepcopy(response), skip_checks=True) test = object_factory[object_type](response, skip_checks=True) - api.extend_items(response=test, key=api.collection_item_map.get(object_type, object_type)) + await api.extend_items(response=test, key=api.collection_item_map.get(object_type, object_type)) test.refresh() self.assert_response_extended(actual=test, expected=original) @@ -417,35 +423,34 @@ def test_extend_items( @pytest.mark.parametrize("object_type", [ RemoteObjectType.PLAYLIST, RemoteObjectType.ALBUM, RemoteObjectType.SHOW, RemoteObjectType.AUDIOBOOK ], ids=idfn) - def test_extend_items_cache( + async def test_extend_items_cache( self, object_type: RemoteObjectType, - responses: dict[str, dict[str, Any]], + response: dict[str, Any], key: str, api_cache: SpotifyAPI, api_mock: SpotifyMock ): + self.reduce_items(response=response, key=key, api=api_cache, api_mock=api_mock) method = "GET" - response = 
list(responses.values())[0][key][api_cache.items_key][0] - url = response["href"] if object_type != RemoteObjectType.PLAYLIST else response[key.rstrip("s")]["href"] - repository = api_cache.handler.cache.get_repository_from_url(url=url) - id_list = None + response = response[key] - # get a response that has not already had its items persisted to the cache by a previous test - for response in responses.values(): - self.reduce_items(response=response, key=key, api=api_cache, api_mock=api_mock) - items = [ - item if object_type != RemoteObjectType.PLAYLIST else item[key.rstrip("s")] - for item in response[key][api_cache.items_key] - ] + if object_type == RemoteObjectType.PLAYLIST: + url = response[api_cache.items_key][0][key.rstrip("s")]["href"] + else: + url = response[api_cache.items_key][0]["href"] + repository = api_cache.handler.session.cache.get_repository_from_url(url=url) + await repository.clear() - id_list = [item[self.id_key] for item in items] - if any((method, id_) not in repository for id_ in id_list): - break + items = [ + item if object_type != RemoteObjectType.PLAYLIST else item[key.rstrip("s")] + for item in response[api_cache.items_key] + ] + id_list = [item[self.id_key] for item in items] - assert any((method, id_) not in repository for id_ in id_list) - api_cache.extend_items(response=response, key=api_cache.collection_item_map.get(object_type, object_type)) - assert all((method, id_) in repository for id_ in id_list) + assert any([not await repository.contains((method, id_)) for id_ in id_list]) + await api_cache.extend_items(response=response, key=api_cache.collection_item_map.get(object_type, object_type)) + assert all([await repository.contains((method, id_)) for id_ in id_list]) ########################################################################### ## ``get_user_items`` @@ -460,7 +465,7 @@ def test_extend_items_cache( (RemoteObjectType.EPISODE, False), (RemoteObjectType.AUDIOBOOK, False), ], ids=idfn) - def test_get_user_items( + async 
def test_get_user_items( self, object_type: RemoteObjectType, user: bool, @@ -488,11 +493,11 @@ def test_get_user_items( total = len(responses) limit = get_limit(total, max_limit=api_mock.limit_max, pages=3) - results = api.get_user_items(user=test, kind=object_type, limit=limit) + results = await api.get_user_items(user=test, kind=object_type, limit=limit) assert len(results) == total # appropriate number of requests made - requests = api_mock.get_requests(url=url) + requests = await api_mock.get_requests(url=url) assert len(requests) == api_mock.calculate_pages(limit=limit, total=total) for result in results: # check results are as expected @@ -518,7 +523,7 @@ def test_get_user_items( RemoteObjectType.CHAPTER: {"languages", "images", "chapters"}, } - def test_get_items_single_string( + async def test_get_items_single_string( self, object_type: RemoteObjectType, response: dict[str, Any], @@ -527,7 +532,7 @@ def test_get_items_single_string( api: SpotifyAPI, api_mock: SpotifyMock ): - results = api.get_items( + results = await api.get_items( values=random_id_type(id_=response[self.id_key], wrangler=api.wrangler, kind=object_type), kind=object_type, extend=extend @@ -536,14 +541,16 @@ def test_get_items_single_string( self.assert_get_items_results( results=results, expected={response[self.id_key]: response}, object_type=object_type, key=key ) - self.assert_get_items_calls(responses=[response], object_type=object_type, key=key, api=api, api_mock=api_mock) + await self.assert_get_items_calls( + responses=[response], object_type=object_type, key=key, api=api, api_mock=api_mock + ) # just check that these don't fail - api.get_items(values=response["uri"]) - api.get_items(values=response["href"]) - api.get_items(values=response["external_urls"]["spotify"]) + await api.get_items(values=response["uri"]) + await api.get_items(values=response["href"]) + await api.get_items(values=response["external_urls"]["spotify"]) - def test_get_items_many_string( + async def 
test_get_items_many_string( self, object_type: RemoteObjectType, responses: dict[str, dict[str, Any]], @@ -557,7 +564,7 @@ def test_get_items_many_string( limit = get_limit(responses, max_limit=api_mock.limit_max) assert len(responses) > limit - results = api.get_items( + results = await api.get_items( values=random_id_types(id_list=responses, wrangler=api.wrangler, kind=object_type), kind=object_type, limit=limit or api_mock.limit_max, @@ -565,11 +572,11 @@ def test_get_items_many_string( ) self.assert_item_types(results=results, key=object_type.name.lower()) self.assert_get_items_results(results=results, expected=responses, object_type=object_type, key=key) - self.assert_get_items_calls( + await self.assert_get_items_calls( responses=responses.values(), object_type=object_type, key=key, limit=limit, api=api, api_mock=api_mock ) - def test_get_items_single_mapping( + async def test_get_items_single_mapping( self, object_type: RemoteObjectType, response: dict[str, Any], @@ -579,7 +586,7 @@ def test_get_items_single_mapping( test = {k: v for k, v in response.items() if k not in self.update_keys[object_type]} self.assert_different(response, test, key) - results = api.get_items(values=test) + results = await api.get_items(values=test) self.assert_item_types(results=results, key=object_type.name.lower()) self.assert_get_items_results( results=results, @@ -589,7 +596,7 @@ def test_get_items_single_mapping( object_type=object_type ) - def test_get_items_many_mapping( + async def test_get_items_many_mapping( self, object_type: RemoteObjectType, responses: dict[str, dict[str, Any]], @@ -604,7 +611,7 @@ def test_get_items_many_mapping( for response in responses.values(): self.assert_different(response, test[response[self.id_key]], key) - results = api.get_items(values=test.values()) + results = await api.get_items(values=test.values()) self.assert_item_types(results=results, key=object_type.name.lower()) self.assert_get_items_results(results=results, expected=responses, 
test=test, key=key, object_type=object_type) @@ -612,7 +619,7 @@ def test_get_items_many_mapping( RemoteObjectType.TRACK, RemoteObjectType.PLAYLIST, RemoteObjectType.ALBUM, # other RemoteResponse types not yet implemented/do not provide expected results ], ids=idfn) - def test_get_items_single_response( + async def test_get_items_single_response( self, object_type: RemoteObjectType, response: dict[str, Any], @@ -629,7 +636,7 @@ def test_get_items_single_response( test = factory({k: v for k, v in response.items() if k not in self.update_keys[object_type]}, skip_checks=True) self.assert_different(original.response, test.response, key) - results = api.get_items(values=test) + results = await api.get_items(values=test) self.assert_get_items_results( results=results, expected={original.id: original.response}, @@ -645,7 +652,7 @@ def test_get_items_single_response( RemoteObjectType.TRACK, RemoteObjectType.PLAYLIST, RemoteObjectType.ALBUM, # other RemoteResponse types not yet implemented/do not provide expected results ], ids=idfn) - def test_get_items_many_response( + async def test_get_items_many_response( self, object_type: RemoteObjectType, responses: dict[str, dict[str, Any]], @@ -667,7 +674,7 @@ def test_get_items_many_response( for orig, ts in zip(original, test): self.assert_different(orig.response, ts.response, key) - results = api.get_items(values=test) + results = await api.get_items(values=test) self.assert_get_items_results( results=results, expected={response.id: response.response for response in original}, @@ -761,7 +768,7 @@ def _assert_extend_tracks_result( if test: assert key not in test[result[self.id_key]] - def assert_extend_tracks_calls( + async def assert_extend_tracks_calls( self, responses: Collection[dict[str, Any]], api: SpotifyAPI, @@ -773,20 +780,20 @@ def assert_extend_tracks_calls( """Assert appropriate number of requests made for extend_tracks method calls""" requests = [] if features and limit > 1: - requests += 
api_mock.get_requests(url=f"{api.url}/audio-features") + requests += await api_mock.get_requests(url=f"{api.url}/audio-features") if analysis and limit > 1: - requests += api_mock.get_requests(url=f"{api.url}/audio-analysis") + requests += await api_mock.get_requests(url=f"{api.url}/audio-analysis") for response in responses: if features: - requests += api_mock.get_requests(url=f"{api.url}/audio-features/{response[self.id_key]}") + requests += await api_mock.get_requests(url=f"{api.url}/audio-features/{response[self.id_key]}") if analysis: - requests += api_mock.get_requests(url=f"{api.url}/audio-analysis/{response[self.id_key]}") + requests += await api_mock.get_requests(url=f"{api.url}/audio-analysis/{response[self.id_key]}") - assert len(api_mock.request_history) == len(list(batched(responses, limit))) + len(responses) + assert len(requests) == len(list(batched(responses, limit))) + len(responses) @pytest.mark.parametrize("object_type", [RemoteObjectType.TRACK], ids=idfn) - def test_extend_tracks_single_string( + async def test_extend_tracks_single_string( self, object_type: RemoteObjectType, response: dict[str, Any], @@ -795,7 +802,7 @@ def test_extend_tracks_single_string( api: SpotifyAPI, api_mock: SpotifyMock ): - results = api.extend_tracks( + results = await api.extend_tracks( values=random_id_type(id_=response[self.id_key], wrangler=api.wrangler, kind=RemoteObjectType.TRACK), features=True, analysis=True @@ -805,15 +812,17 @@ def test_extend_tracks_single_string( self.assert_extend_tracks_results( results=results, features={response[self.id_key]: features}, analysis={response[self.id_key]: analysis} ) - self.assert_extend_tracks_calls(responses=[response], features=True, analysis=True, api=api, api_mock=api_mock) + await self.assert_extend_tracks_calls( + responses=[response], features=True, analysis=True, api=api, api_mock=api_mock + ) # just check that these don't fail - api.extend_tracks(values=response["uri"]) - 
api.extend_tracks(values=response["href"]) - api.extend_tracks(values=response["external_urls"]["spotify"]) + await api.extend_tracks(values=response["uri"]) + await api.extend_tracks(values=response["href"]) + await api.extend_tracks(values=response["external_urls"]["spotify"]) @pytest.mark.parametrize("object_type", [RemoteObjectType.TRACK], ids=idfn) - def test_extend_tracks_many_string( + async def test_extend_tracks_many_string( self, object_type: RemoteObjectType, responses: dict[str, Any], @@ -825,15 +834,15 @@ def test_extend_tracks_many_string( limit = get_limit(responses, max_limit=api_mock.limit_max) test = random_id_types(id_list=responses, wrangler=api.wrangler, kind=RemoteObjectType.TRACK) - results = api.extend_tracks(values=test, features=True, analysis=True, limit=limit) + results = await api.extend_tracks(values=test, features=True, analysis=True, limit=limit) self.assert_extend_tracks_results(results=results, features=features_all, analysis=analysis_all) - self.assert_extend_tracks_calls( + await self.assert_extend_tracks_calls( responses=responses.values(), features=True, analysis=True, api=api, api_mock=api_mock, limit=limit ) @pytest.mark.parametrize("object_type", [RemoteObjectType.TRACK], ids=idfn) - def test_extend_tracks_single_mapping( + async def test_extend_tracks_single_mapping( self, object_type: RemoteObjectType, response: dict[str, Any], @@ -841,12 +850,12 @@ def test_extend_tracks_single_mapping( analysis: dict[str, Any], api: SpotifyAPI, ): - results = api.extend_tracks(values=response, features=True, analysis=False) + results = await api.extend_tracks(values=response, features=True, analysis=False) self.assert_extend_tracks_results( results=results, test={response[self.id_key]: response}, features={response[self.id_key]: features}, ) - results = api.extend_tracks(values=response, features=False, analysis=True) + results = await api.extend_tracks(values=response, features=False, analysis=True) self.assert_extend_tracks_results( 
results=results, test={response[self.id_key]: response}, @@ -857,7 +866,7 @@ def test_extend_tracks_single_mapping( # noinspection PyTestUnpassedFixture @pytest.mark.parametrize("object_type", [RemoteObjectType.TRACK], ids=idfn) - def test_extend_tracks_many_mapping( + async def test_extend_tracks_many_mapping( self, object_type: RemoteObjectType, responses: dict[str, Any], @@ -866,16 +875,16 @@ def test_extend_tracks_many_mapping( api: SpotifyAPI, api_mock: SpotifyMock, ): - results = api.extend_tracks(values=responses.values(), features=True, analysis=False) + results = await api.extend_tracks(values=responses.values(), features=True, analysis=False) self.assert_extend_tracks_results(results=results, test=responses, features=features_all) - results = api.extend_tracks(values=responses.values(), features=False, analysis=True) + results = await api.extend_tracks(values=responses.values(), features=False, analysis=True) self.assert_extend_tracks_results( results=results, test=responses, features=features_all, features_in_results=False, analysis=analysis_all, ) @pytest.mark.parametrize("object_type", [RemoteObjectType.TRACK], ids=idfn) - def test_extend_tracks_single_response( + async def test_extend_tracks_single_response( self, object_type: RemoteObjectType, response: dict[str, Any], @@ -890,11 +899,11 @@ def test_extend_tracks_single_response( test: SpotifyTrack = object_factory[object_type](response, skip_checks=True) assert test.bpm is None - results = api.extend_tracks(values=response, features=True, analysis=False) + results = await api.extend_tracks(values=response, features=True, analysis=False) self.assert_extend_tracks_results(results=results, test={test.id: test.response}, features={test.id: features}) assert test.bpm is not None - results = api.extend_tracks(values=response, features=False, analysis=True) + results = await api.extend_tracks(values=response, features=False, analysis=True) self.assert_extend_tracks_results( results=results, test={test.id: 
test.response}, @@ -906,7 +915,7 @@ def test_extend_tracks_single_response( # noinspection PyTestUnpassedFixture @pytest.mark.parametrize("object_type", [RemoteObjectType.TRACK], ids=idfn) - def test_extend_tracks_many_response( + async def test_extend_tracks_many_response( self, object_type: RemoteObjectType, responses: dict[str, dict[str, Any]], @@ -923,18 +932,18 @@ def test_extend_tracks_many_response( for t in test: assert t.bpm is None - results = api.extend_tracks(values=responses.values(), features=True, analysis=False) + results = await api.extend_tracks(values=responses.values(), features=True, analysis=False) self.assert_extend_tracks_results(results=results, test=responses, features=features_all) for t in test: assert t.bpm is not None - results = api.extend_tracks(values=responses.values(), features=False, analysis=True) + results = await api.extend_tracks(values=responses.values(), features=False, analysis=True) self.assert_extend_tracks_results( results=results, test=responses, features=features_all, features_in_results=False, analysis=analysis_all, ) @pytest.mark.parametrize("object_type", [RemoteObjectType.TRACK], ids=idfn) - def test_get_tracks( + async def test_get_tracks( self, object_type: RemoteObjectType, response: dict[str, Any], @@ -942,7 +951,7 @@ def test_get_tracks( api_mock: SpotifyMock, object_factory: SpotifyObjectFactory, ): - results = api.get_tracks( + results = await api.get_tracks( values=random_id_type(id_=response[self.id_key], wrangler=api.wrangler, kind=RemoteObjectType.TRACK), features=True, analysis=True @@ -952,7 +961,7 @@ def test_get_tracks( assert "audio_analysis" not in response test_response = deepcopy(response) - results = api.get_tracks(values=test_response, features=True, analysis=True) + results = await api.get_tracks(values=test_response, features=True, analysis=True) assert results[0] == test_response assert "audio_features" in test_response assert "audio_analysis" in test_response @@ -961,7 +970,7 @@ def 
test_get_tracks( test_object: SpotifyTrack = object_factory[object_type](response, skip_checks=True) assert test_object.bpm is None - api.get_tracks(values=response, features=True, analysis=True) + await api.get_tracks(values=response, features=True, analysis=True) assert "audio_features" in response assert "audio_analysis" in response assert test_object.bpm is not None diff --git a/tests/libraries/remote/spotify/api/test_misc.py b/tests/libraries/remote/spotify/api/test_misc.py index 2a530c2c..6bce14e1 100644 --- a/tests/libraries/remote/spotify/api/test_misc.py +++ b/tests/libraries/remote/spotify/api/test_misc.py @@ -1,6 +1,6 @@ from copy import deepcopy from typing import Any -from urllib.parse import parse_qs +from urllib.parse import unquote import pytest @@ -18,7 +18,7 @@ class TestSpotifyAPIMisc: ("query", {"query": "valid query", "kind": ObjectType.TRACK}, 1, 50), ("get_user_items", {"kind": ObjectType.PLAYLIST}, 1, 50) ], ids=idfn) - def test_limit_param_limited( + async def test_limit_param_limited( self, method_name: str, kwargs: dict[str, Any], @@ -28,40 +28,40 @@ def test_limit_param_limited( api_mock: SpotifyMock ): # too small - getattr(api, method_name)(limit=floor - 20, **kwargs) - params = parse_qs(api_mock.last_request.query) - assert "limit" in params - assert int(params["limit"][0]) == floor + await getattr(api, method_name)(limit=floor - 20, **kwargs) + url, _, _ = next(reversed(await api_mock.get_requests())) + assert "limit" in url.query + assert int(url.query["limit"]) == floor # good value limit = floor + (ceil // 2) - getattr(api, method_name)(limit=limit, **kwargs) - params = parse_qs(api_mock.last_request.query) - assert "limit" in params - assert int(params["limit"][0]) == limit + await getattr(api, method_name)(limit=limit, **kwargs) + url, _, _ = next(reversed(await api_mock.get_requests())) + assert "limit" in url.query + assert int(url.query["limit"]) == limit # too big - getattr(api, method_name)(limit=ceil + 100, **kwargs) - 
params = parse_qs(api_mock.last_request.query) - assert "limit" in params - assert int(params["limit"][0]) == ceil + await getattr(api, method_name)(limit=ceil + 100, **kwargs) + url, _, _ = next(reversed(await api_mock.get_requests())) + assert "limit" in url.query + assert int(url.query["limit"]) == ceil ########################################################################### ## /me + /search endpoints ########################################################################### - def test_get_self(self, api: SpotifyAPI, api_mock: SpotifyMock): + async def test_get_self(self, api: SpotifyAPI, api_mock: SpotifyMock): api.user_data = {} - assert api.get_self(update_user_data=False) == api_mock.user + assert await api.get_self(update_user_data=False) == api_mock.user assert api.user_data == {} - assert api.get_self(update_user_data=True) == api_mock.user + assert await api.get_self(update_user_data=True) == api_mock.user assert api.user_data == api_mock.user - def test_query_input_validation(self, api: SpotifyAPI, api_mock: SpotifyMock): - assert api.query(query=None, kind=ObjectType.EPISODE) == [] - assert api.query(query="", kind=ObjectType.SHOW) == [] + async def test_query_input_validation(self, api: SpotifyAPI, api_mock: SpotifyMock): + assert await api.query(query=None, kind=ObjectType.EPISODE) == [] + assert await api.query(query="", kind=ObjectType.SHOW) == [] # long queries that would cause the API to give an error should fail safely - assert api.query(query=random_str(151, 200), kind=ObjectType.CHAPTER) == [] + assert await api.query(query=random_str(151, 200), kind=ObjectType.CHAPTER) == [] @pytest.mark.parametrize("kind,query,limit", [ (ObjectType.PLAYLIST, "super cool playlist", 5), @@ -72,7 +72,7 @@ def test_query_input_validation(self, api: SpotifyAPI, api_mock: SpotifyMock): (ObjectType.EPISODE, "incredible episode", 25), (ObjectType.AUDIOBOOK, "i love this audiobook", 6), ], ids=idfn) - def test_query( + async def test_query( self, kind: 
ObjectType, query: str, @@ -81,18 +81,16 @@ def test_query( api_mock: SpotifyMock, ): expected = api_mock.item_type_map[kind] - results = api.query(query=query, kind=kind, limit=limit) + results = await api.query(query=query, kind=kind, limit=limit) assert len(results) <= min(len(expected), limit) for result in results: assert result["type"] == kind.name.lower() - request = api_mock.get_requests(url=f"{api.url}/search", params={"q": query})[0] - params = parse_qs(request.query) - - assert params["q"][0] == query - assert int(params["limit"][0]) == limit - assert params["type"][0] == kind.name.lower() + url, _, _ = next(iter(await api_mock.get_requests(url=f"{api.url}/search", params={"q": query}))) + assert unquote(url.query["q"]) == query + assert int(url.query["limit"]) == limit + assert unquote(url.query["type"]) == kind.name.lower() ########################################################################### ## Utilities @@ -100,18 +98,18 @@ def test_query( @pytest.mark.parametrize("kind", [ ObjectType.PLAYLIST, ObjectType.ALBUM, ObjectType.SHOW, ObjectType.AUDIOBOOK, ], ids=idfn) - def test_pretty_print_uris( + async def test_pretty_print_uris( self, kind: ObjectType, api: SpotifyAPI, api_mock: SpotifyMock, capfd: pytest.CaptureFixture ): key = api.collection_item_map.get(kind, kind).name.lower() + "s" source = deepcopy(next(item for item in api_mock.item_type_map[kind] if item[key]["total"] > 50)) - api.print_collection(value=source) + await api.print_collection(value=source) stdout = get_stdout(capfd) # printed in blocks blocks = [block for block in stdout.strip().split("\n\n") if SpotifyDataWrangler.url_ext in block] - assert len(blocks) == len(api_mock.request_history) + assert len(blocks) == api_mock.total_requests # lines printed = total tracks + 1 extra for title lines = [line for line in stdout.strip().split("\n") if SpotifyDataWrangler.url_ext in line] diff --git a/tests/libraries/remote/spotify/api/test_playlist.py 
b/tests/libraries/remote/spotify/api/test_playlist.py index 350ccae9..a3bc8966 100644 --- a/tests/libraries/remote/spotify/api/test_playlist.py +++ b/tests/libraries/remote/spotify/api/test_playlist.py @@ -3,6 +3,8 @@ from typing import Any import pytest +from aioresponses.core import RequestCall +from yarl import URL from musify import PROGRAM_NAME from musify.libraries.remote.core.enum import RemoteIDType, RemoteObjectType @@ -34,69 +36,84 @@ def playlist_unique(api_mock: SpotifyMock) -> dict[str, Any]: if names.count(pl["name"]) == 1 and len(pl["name"]) != RemoteIDType.ID.value ) + @staticmethod + def _get_payload_from_request(request: RequestCall) -> dict[str, Any] | None: + return request.kwargs.get("body", request.kwargs.get("json")) + + @classmethod + async def _get_payloads_from_url_base(cls, url: str | URL, api_mock: SpotifyMock) -> list[dict[str, Any]]: + return [ + cls._get_payload_from_request(req) for _, req, _ in await api_mock.get_requests(url=url) + if cls._get_payload_from_request(req) + ] + ########################################################################### ## Basic functionality ########################################################################### - def test_get_playlist_url(self, playlist_unique: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): - assert api.get_playlist_url(playlist=playlist_unique) == playlist_unique["href"] - assert api.get_playlist_url(playlist=playlist_unique["name"]) == playlist_unique["href"] + async def test_get_playlist_url(self, playlist_unique: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): + assert await api.get_playlist_url(playlist=playlist_unique) == playlist_unique["href"] + assert await api.get_playlist_url(playlist=playlist_unique["name"]) == playlist_unique["href"] pl_object = SpotifyPlaylist(playlist_unique, skip_checks=True) - assert api.get_playlist_url(playlist=pl_object) == playlist_unique["href"] + assert await api.get_playlist_url(playlist=pl_object) == 
playlist_unique["href"] with pytest.raises(RemoteIDTypeError): - api.get_playlist_url("does not exist") + await api.get_playlist_url("does not exist") ########################################################################### ## POST playlist operations ########################################################################### - def test_create_playlist(self, api: SpotifyAPI, api_mock: SpotifyMock): + async def test_create_playlist(self, api: SpotifyAPI, api_mock: SpotifyMock): name = "test playlist" url = f"{api.url}/users/{api_mock.user_id}/playlists" - result = api.create_playlist(name=name, public=False, collaborative=True) + result = await api.create_playlist(name=name, public=False, collaborative=True) - body = api_mock.get_requests(url=url, response={"name": name})[0].json() + _, _, response = next(iter(await api_mock.get_requests(url=url, response={"name": name}))) + body = await response.json() assert body["name"] == name assert PROGRAM_NAME in body["description"] assert not body["public"] assert body["collaborative"] assert result.removeprefix(f"{api.url}/playlists/").strip("/") - def test_add_to_playlist_input_validation_and_skips(self, api: SpotifyAPI, api_mock: SpotifyMock): + async def test_add_to_playlist_input_validation_and_skips(self, api: SpotifyAPI, api_mock: SpotifyMock): url = f"{api.url}/playlists/{random_id()}" for kind in ALL_ITEM_TYPES: if kind == RemoteObjectType.TRACK: continue with pytest.raises(RemoteObjectTypeError): - api.add_to_playlist(playlist=url, items=random_uris(kind=kind)) + await api.add_to_playlist(playlist=url, items=random_uris(kind=kind)) with pytest.raises(RemoteObjectTypeError): - api.add_to_playlist(playlist=url, items=random_api_urls(kind=kind)) + await api.add_to_playlist(playlist=url, items=random_api_urls(kind=kind)) with pytest.raises(RemoteObjectTypeError): - api.add_to_playlist(playlist=url, items=random_ext_urls(kind=kind)) + await api.add_to_playlist(playlist=url, items=random_ext_urls(kind=kind)) - assert 
api.add_to_playlist(playlist=url, items=()) == 0 + assert await api.add_to_playlist(playlist=url, items=()) == 0 with pytest.raises(RemoteIDTypeError): - api.add_to_playlist(playlist="does not exist", items=random_ids()) + await api.add_to_playlist(playlist="does not exist", items=random_ids()) - def test_add_to_playlist_batches_limited(self, playlist: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): + async def test_add_to_playlist_batches_limited( + self, playlist: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock + ): id_list = random_ids(200, 300) valid_limit = 80 - api.add_to_playlist(playlist=playlist["href"], items=sample(id_list, k=10), limit=-30, skip_dupes=False) - api.add_to_playlist(playlist=playlist["href"], items=id_list, limit=200, skip_dupes=False) - api.add_to_playlist(playlist=playlist["href"], items=id_list, limit=valid_limit, skip_dupes=False) + await api.add_to_playlist(playlist=playlist["href"], items=sample(id_list, k=10), limit=-30, skip_dupes=False) + await api.add_to_playlist(playlist=playlist["href"], items=id_list, limit=200, skip_dupes=False) + await api.add_to_playlist(playlist=playlist["href"], items=id_list, limit=valid_limit, skip_dupes=False) - requests = api_mock.get_requests(url=playlist["href"] + "/tracks") + requests = await api_mock.get_requests(url=playlist["href"] + "/tracks") - for i, request in enumerate(requests, 1): - count = len(request.json()["uris"]) + for i, (_, request, _) in enumerate(requests, 1): + payload = self._get_payload_from_request(request) + count = len(payload["uris"]) assert count >= 1 assert count <= 100 - def test_add_to_playlist(self, playlist: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): + async def test_add_to_playlist(self, playlist: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): total = playlist["tracks"]["total"] limit = total // 3 assert total > limit # ensure ranges are valid for test to work @@ -106,29 +123,26 @@ def test_add_to_playlist(self, playlist: 
dict[str, Any], api: SpotifyAPI, api_mo ) assert len(id_list) < total - result = api.add_to_playlist(playlist=playlist["id"], items=id_list, limit=limit, skip_dupes=False) + result = await api.add_to_playlist(playlist=playlist["id"], items=id_list, limit=limit, skip_dupes=False) assert result == len(id_list) uris = [] - for request in api_mock.get_requests(url=playlist["href"] + "/tracks"): - if not request.body: - continue - - request_body = request.json() - if "uris" in request_body: - uris.extend(request_body["uris"]) + for _, request, _ in await api_mock.get_requests(url=playlist["href"] + "/tracks"): + payload = self._get_payload_from_request(request) + if "uris" in payload: + uris.extend(payload["uris"]) assert len(uris) == len(id_list) # check same results for other input types - result = api.add_to_playlist(playlist=playlist, items=id_list, limit=limit, skip_dupes=False) + result = await api.add_to_playlist(playlist=playlist, items=id_list, limit=limit, skip_dupes=False) assert result == len(id_list) pl = SpotifyPlaylist(playlist, skip_checks=True) - result = api.add_to_playlist(playlist=pl, items=id_list, limit=limit, skip_dupes=False) + result = await api.add_to_playlist(playlist=pl, items=id_list, limit=limit, skip_dupes=False) assert result == len(id_list) - def test_add_to_playlist_with_skip(self, playlist: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): - api.extend_items(playlist["tracks"]) + async def test_add_to_playlist_with_skip(self, playlist: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): + await api.extend_items(playlist["tracks"]) initial = len(playlist["tracks"]["items"]) total = playlist["tracks"]["total"] @@ -141,72 +155,69 @@ def test_add_to_playlist_with_skip(self, playlist: dict[str, Any], api: SpotifyA wrangler=api.wrangler, kind=RemoteObjectType.TRACK, start=api_mock.limit_lower, stop=randrange(20, 30) ) - result = api.add_to_playlist(playlist=playlist["uri"], items=id_list_dupes + id_list_new, limit=limit) + result 
= await api.add_to_playlist(playlist=playlist["uri"], items=id_list_dupes + id_list_new, limit=limit) assert result == len(id_list_new) uris = [] - for request in api_mock.get_requests(url=playlist["href"] + "/tracks"): - if not request.body: - continue - - request_body = request.json() - if "uris" in request_body: - uris.extend(request_body["uris"]) + for _, request, _ in await api_mock.get_requests(url=playlist["href"] + "/tracks"): + payload = self._get_payload_from_request(request) + if payload and "uris" in payload: + uris.extend(payload["uris"]) assert len(uris) == len(id_list_new) ########################################################################### ## DELETE playlist operations ########################################################################### - def test_delete_playlist(self, playlist_unique: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): - result = api.delete_playlist( + async def test_delete_playlist(self, playlist_unique: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): + result = await api.delete_playlist( random_id_type(id_=playlist_unique["id"], wrangler=api.wrangler, kind=RemoteObjectType.PLAYLIST) ) assert result == playlist_unique["href"] + "/followers" - result = api.delete_playlist(playlist_unique) + result = await api.delete_playlist(playlist_unique) assert result == playlist_unique["href"] + "/followers" - result = api.delete_playlist(SpotifyPlaylist(playlist_unique, skip_checks=True)) + result = await api.delete_playlist(SpotifyPlaylist(playlist_unique, skip_checks=True)) assert result == playlist_unique["href"] + "/followers" - def test_clear_from_playlist_input_validation_and_skips(self, api: SpotifyAPI, api_mock: SpotifyMock): + async def test_clear_from_playlist_input_validation_and_skips(self, api: SpotifyAPI, api_mock: SpotifyMock): url = f"{api.url}/playlists/{random_id()}" for kind in ALL_ITEM_TYPES: if kind == RemoteObjectType.TRACK: continue with pytest.raises(RemoteObjectTypeError): - 
api.clear_from_playlist(playlist=url, items=random_uris(kind=kind)) + await api.clear_from_playlist(playlist=url, items=random_uris(kind=kind)) with pytest.raises(RemoteObjectTypeError): - api.clear_from_playlist(playlist=url, items=random_api_urls(kind=kind)) + await api.clear_from_playlist(playlist=url, items=random_api_urls(kind=kind)) with pytest.raises(RemoteObjectTypeError): - api.clear_from_playlist(playlist=url, items=random_ext_urls(kind=kind)) + await api.clear_from_playlist(playlist=url, items=random_ext_urls(kind=kind)) - result = api.clear_from_playlist(playlist=url, items=()) + result = await api.clear_from_playlist(playlist=url, items=()) assert result == 0 with pytest.raises(RemoteIDTypeError): - api.add_to_playlist(playlist="does not exist", items=random_ids()) + await api.add_to_playlist(playlist="does not exist", items=random_ids()) - def test_clear_from_playlist_batches_limited( + async def test_clear_from_playlist_batches_limited( self, playlist: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock ): id_list = random_ids(200, 300) valid_limit = 80 - api.clear_from_playlist(playlist=playlist["href"], items=sample(id_list, k=10), limit=-30) - api.clear_from_playlist(playlist=playlist["href"], items=id_list, limit=valid_limit) - api.clear_from_playlist(playlist=playlist["href"], items=id_list, limit=200) + await api.clear_from_playlist(playlist=playlist["href"], items=sample(id_list, k=10), limit=-30) + await api.clear_from_playlist(playlist=playlist["href"], items=id_list, limit=valid_limit) + await api.clear_from_playlist(playlist=playlist["href"], items=id_list, limit=200) - requests = [req.json() for req in api_mock.get_requests(url=playlist["href"] + "/tracks") if req.body] - for i, body in enumerate(requests, 1): - count = len(body["tracks"]) + requests = await self._get_payloads_from_url_base(url=playlist["href"] + "/tracks", api_mock=api_mock) + for i, payload in enumerate(requests, 1): + count = len(payload["tracks"]) assert count >= 1 
assert count <= 100 - def test_clear_from_playlist_items(self, playlist: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): + async def test_clear_from_playlist_items(self, playlist: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): total = playlist["tracks"]["total"] limit = total // 3 assert total > limit # ensure ranges are valid for test to work @@ -216,29 +227,29 @@ def test_clear_from_playlist_items(self, playlist: dict[str, Any], api: SpotifyA ) assert len(id_list) < total - result = api.clear_from_playlist(playlist=playlist["uri"], items=id_list, limit=limit) + result = await api.clear_from_playlist(playlist=playlist["uri"], items=id_list, limit=limit) assert result == len(id_list) - requests = [req.json() for req in api_mock.get_requests(url=playlist["href"] + "/tracks") if req.body] + requests = await self._get_payloads_from_url_base(url=playlist["href"] + "/tracks", api_mock=api_mock) assert all("tracks" in body for body in requests) assert len([uri["uri"] for req in requests for uri in req["tracks"]]) == len(id_list) # check same results for other input types - result = api.clear_from_playlist(playlist=playlist, items=id_list, limit=limit) + result = await api.clear_from_playlist(playlist=playlist, items=id_list, limit=limit) assert result == len(id_list) pl = SpotifyPlaylist(playlist, skip_checks=True) - result = api.clear_from_playlist(playlist=pl, items=id_list, limit=limit) + result = await api.clear_from_playlist(playlist=pl, items=id_list, limit=limit) assert result == len(id_list) - def test_clear_from_playlist_all(self, playlist: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): + async def test_clear_from_playlist_all(self, playlist: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): total = playlist["tracks"]["total"] limit = total // 4 assert total > limit # ensure ranges are valid for test to work - result = api.clear_from_playlist(playlist=playlist, limit=limit) + result = await 
api.clear_from_playlist(playlist=playlist, limit=limit) assert result == total - requests = [req.json() for req in api_mock.get_requests(url=playlist["href"] + "/tracks") if req.body] + requests = await self._get_payloads_from_url_base(url=playlist["href"] + "/tracks", api_mock=api_mock) assert all("tracks" in body for body in requests) assert len([uri["uri"] for body in requests for uri in body["tracks"]]) == total diff --git a/tests/libraries/remote/spotify/api/utils.py b/tests/libraries/remote/spotify/api/utils.py index 8aebfd8c..2f86cdfa 100644 --- a/tests/libraries/remote/spotify/api/utils.py +++ b/tests/libraries/remote/spotify/api/utils.py @@ -2,9 +2,6 @@ from itertools import batched from typing import Any -# noinspection PyProtectedMember,PyUnresolvedReferences -from requests_mock.request import _RequestObjectProxy as Request - from tests.libraries.remote.spotify.api.mock import SpotifyMock @@ -22,7 +19,7 @@ def get_limit(values: Collection | int, max_limit: int, pages: int = 3) -> int: def assert_calls( expected: Collection[Mapping[str, Any]], - requests: list[Request], + requests: Collection, api_mock: SpotifyMock, key: str | None = None, limit: int | None = None, diff --git a/tests/libraries/remote/spotify/conftest.py b/tests/libraries/remote/spotify/conftest.py index dc2096ad..e60de518 100644 --- a/tests/libraries/remote/spotify/conftest.py +++ b/tests/libraries/remote/spotify/conftest.py @@ -36,5 +36,5 @@ def api_mock(_api_mock: SpotifyMock) -> SpotifyMock: Creates a copy of ``_api_mock`` to allow for successful requests history assertions. 
""" mock = copy(_api_mock) - mock.reset_mock() + mock.reset() return mock diff --git a/tests/libraries/remote/spotify/object/test_album.py b/tests/libraries/remote/spotify/object/test_album.py index 32a5bb97..960956ff 100644 --- a/tests/libraries/remote/spotify/object/test_album.py +++ b/tests/libraries/remote/spotify/object/test_album.py @@ -44,13 +44,13 @@ def response_random(self, api_mock: SpotifyMock) -> dict[str, Any]: return response @pytest.fixture(scope="class") - def _response_valid(self, api: SpotifyAPI, _api_mock: SpotifyMock) -> dict[str, Any]: + async def _response_valid(self, api: SpotifyAPI, _api_mock: SpotifyMock) -> dict[str, Any]: response = next( deepcopy(album) for album in _api_mock.albums if album["tracks"]["total"] > len(album["tracks"]["items"]) > 5 and album["genres"] and album["artists"] ) - api.extend_items(response=response, key=RemoteObjectType.TRACK) + await api.extend_items(response=response, key=RemoteObjectType.TRACK) return response @@ -62,11 +62,11 @@ def response_valid(self, _response_valid: dict[str, Any]) -> dict[str, Any]: """ return deepcopy(_response_valid) - def test_input_validation(self, response_random: dict[str, Any], api_mock: SpotifyMock): + async def test_input_validation(self, response_random: dict[str, Any], api_mock: SpotifyMock): with pytest.raises(RemoteObjectTypeError): SpotifyAlbum(api_mock.generate_playlist(item_count=0)) with pytest.raises(APIError): - SpotifyAlbum(response_random).reload() + await SpotifyAlbum(response_random).reload() response_random["total_tracks"] += 10 with pytest.raises(RemoteError): @@ -186,7 +186,7 @@ def test_refresh(self, response_valid: dict[str, Any]): assert "genres" in track.response["album"] assert track.disc_total == original_disc_total + 5 - def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): + async def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): response_valid.pop("genres", None) response_valid.pop("popularity", None) @@ -202,7 
+202,7 @@ def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): assert album.compilation != is_compilation album.api = api - album.reload(extend_artists=True, extend_features=False) + await album.reload(extend_artists=True, extend_features=False) assert album.genres assert album.rating is not None assert album.compilation == is_compilation @@ -211,13 +211,13 @@ def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): ## Load method tests ########################################################################### @staticmethod - def get_load_without_items( + async def get_load_without_items( loader: SpotifyAlbum, response_valid: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock ): - return loader.load(response_valid["href"], api=api, extend_tracks=True) + return await loader.load(response_valid["href"], api=api, extend_tracks=True) @pytest.fixture def load_items( @@ -230,7 +230,6 @@ def load_items( :return: The extracted response as SpotifyTracks. """ - # ensure extension of items can be made by reducing available items limit = response_valid[item_key]["limit"] response_valid[item_key]["items"] = response_valid[item_key][api.items_key][:limit] @@ -260,17 +259,17 @@ def load_items( return items - def test_load_with_all_items( + async def test_load_with_all_items( self, response_valid: dict[str, Any], item_key: str, api: SpotifyAPI, api_mock: SpotifyMock ): load_items = [SpotifyTrack(response) for response in response_valid[item_key][api.items_key]] - SpotifyAlbum.load( + await SpotifyAlbum.load( response_valid, api=api, items=load_items, extend_albums=True, extend_tracks=False, extend_features=False ) - assert not api_mock.request_history + api_mock.assert_not_called() - def test_load_with_some_items( + async def test_load_with_some_items( self, response_valid: dict[str, Any], item_key: str, @@ -280,25 +279,25 @@ def test_load_with_some_items( ): kind = RemoteObjectType.ALBUM - result: SpotifyAlbum = SpotifyAlbum.load( + result: 
SpotifyAlbum = await SpotifyAlbum.load( response_valid, api=api, items=load_items, extend_tracks=True, extend_features=True ) - self.assert_load_with_items_requests( + await self.assert_load_with_items_requests( response=response_valid, result=result, items=load_items, key=item_key, api_mock=api_mock ) - self.assert_load_with_items_extended( + await self.assert_load_with_items_extended( response=response_valid, result=result, items=load_items, kind=kind, key=item_key, api_mock=api_mock ) # requests for extension data expected = api_mock.calculate_pages_from_response(response_valid) # -1 for not calling initial page - assert len(api_mock.get_requests(re.compile(f"{result.url}/{item_key}"))) == expected - 1 - assert len(api_mock.get_requests(re.compile(f"{api.url}/audio-features"))) == expected - assert not api_mock.get_requests(re.compile(f"{api.url}/artists")) # did not extend artists + assert len(await api_mock.get_requests(url=f"{result.url}/{item_key}")) == expected - 1 + assert len(await api_mock.get_requests(url=f"{api.url}/audio-features")) == expected + assert not await api_mock.get_requests(url=f"{api.url}/artists") # did not extend artists - def test_load_with_some_items_and_no_extension( + async def test_load_with_some_items_and_no_extension( self, response_valid: dict[str, Any], item_kind: RemoteObjectType, @@ -307,22 +306,23 @@ def test_load_with_some_items_and_no_extension( api: SpotifyAPI, api_mock: SpotifyMock ): - api.extend_items(response_valid, kind=RemoteObjectType.ALBUM, key=item_kind) - api_mock.reset_mock() # reset for new requests checks to work correctly + await api.extend_items(response_valid, kind=RemoteObjectType.ALBUM, key=item_kind) + api_mock.reset() # reset for new requests checks to work correctly assert len(response_valid[item_key][api.items_key]) == response_valid[item_key]["total"] - assert not api_mock.get_requests(response_valid[item_key]["href"]) + assert not await api_mock.get_requests(url=response_valid[item_key]["href"]) - 
result: SpotifyAlbum = SpotifyAlbum.load( + result: SpotifyAlbum = await SpotifyAlbum.load( response_valid, api=api, items=load_items, extend_artists=True, extend_tracks=True, extend_features=True ) - self.assert_load_with_items_requests( + await self.assert_load_with_items_requests( response=response_valid, result=result, items=load_items, key=item_key, api_mock=api_mock ) # requests for extension data expected = api_mock.calculate_pages_from_response(response_valid) - assert not api_mock.get_requests(re.compile(f"{result.url}/{item_key}")) # already extended on input - assert len(api_mock.get_requests(re.compile(f"{api.url}/audio-features"))) == expected - assert api_mock.get_requests(re.compile(f"{api.url}/artists")) # called the artists endpoint at least once + assert not await api_mock.get_requests(url=f"{result.url}/{item_key}") # already extended on input + assert len(await api_mock.get_requests(url=f"{api.url}/audio-features")) == expected + # called the artists endpoint at least once + assert await api_mock.get_requests(url=re.compile(f"{api.url}/artists")) diff --git a/tests/libraries/remote/spotify/object/test_artist.py b/tests/libraries/remote/spotify/object/test_artist.py index b4830bb6..74f8fe12 100644 --- a/tests/libraries/remote/spotify/object/test_artist.py +++ b/tests/libraries/remote/spotify/object/test_artist.py @@ -1,4 +1,3 @@ -import re from collections.abc import Iterable from copy import deepcopy from random import randrange @@ -79,11 +78,11 @@ def response_valid(self, api_mock: SpotifyMock) -> dict[str, Any]: ) return artist | {"albums": items_block} - def test_input_validation(self, response_random: dict[str, Any], api_mock: SpotifyMock): + async def test_input_validation(self, response_random: dict[str, Any], api_mock: SpotifyMock): with pytest.raises(RemoteObjectTypeError): SpotifyArtist(api_mock.generate_track(artists=False, album=False)) with pytest.raises(APIError): - SpotifyArtist(response_random).reload() + await 
SpotifyArtist(response_random).reload() def test_attributes(self, response_random: dict[str, Any]): artist = SpotifyArtist(response_random) @@ -132,7 +131,7 @@ def test_refresh(self, response_valid: dict[str, Any]): artist.refresh(skip_checks=True) assert len(artist.albums) == original_album_count // 2 - def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): + async def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): genres = response_valid.pop("genres", None) response_valid.pop("popularity", None) response_valid.pop("followers", None) @@ -150,7 +149,7 @@ def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): assert not artist.tracks artist.api = api - artist.reload(extend_albums=False, extend_tracks=True) + await artist.reload(extend_albums=False, extend_tracks=True) if genres: assert artist.genres assert artist.rating is not None @@ -159,26 +158,26 @@ def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): assert not artist.artists assert not artist.tracks - artist.reload(extend_albums=True, extend_tracks=False) + await artist.reload(extend_albums=True, extend_tracks=False) assert {album.id for album in artist._albums} == album_ids assert len(artist.artists) == len(artist_names) assert set(artist.artists) == artist_names assert not artist.tracks - artist.reload(extend_albums=True, extend_tracks=True) + await artist.reload(extend_albums=True, extend_tracks=True) assert artist.tracks ########################################################################### ## Load method tests ########################################################################### @staticmethod - def get_load_without_items( + async def get_load_without_items( loader: SpotifyArtist, response_valid: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock ): - return loader.load(response_valid["href"], api=api, extend_albums=True, extend_tracks=True) + return await loader.load(response_valid["href"], api=api, 
extend_albums=True, extend_tracks=True) @pytest.fixture def load_items( @@ -216,17 +215,17 @@ def load_items( return items - def test_load_with_all_items( + async def test_load_with_all_items( self, response_valid: dict[str, Any], item_key: str, api: SpotifyAPI, api_mock: SpotifyMock ): load_items = [SpotifyAlbum(response, skip_checks=True) for response in response_valid[item_key][api.items_key]] - SpotifyArtist.load( + await SpotifyArtist.load( response_valid, api=api, items=load_items, extend_albums=True, extend_tracks=False, extend_features=False ) - assert not api_mock.request_history + api_mock.assert_not_called() - def test_load_with_some_items( + async def test_load_with_some_items( self, response_valid: dict[str, Any], item_key: str, @@ -236,31 +235,31 @@ def test_load_with_some_items( ): kind = RemoteObjectType.ARTIST - result: SpotifyArtist = SpotifyArtist.load( + result: SpotifyArtist = await SpotifyArtist.load( response_valid, api=api, items=load_items, extend_albums=True, extend_tracks=True, extend_features=True ) - self.assert_load_with_items_requests( + await self.assert_load_with_items_requests( response=response_valid, result=result, items=load_items, key=item_key, api_mock=api_mock ) - self.assert_load_with_items_extended( + await self.assert_load_with_items_extended( response=response_valid, result=result, items=load_items, kind=kind, key=item_key, api_mock=api_mock ) # requests for extension data expected = api_mock.calculate_pages_from_response(response_valid, item_key=item_key) - assert len(api_mock.get_requests(re.compile(f"{result.url}/{item_key}"))) == expected + assert len(await api_mock.get_requests(url=f"{result.url}/{item_key}")) == expected for album in result.response[item_key][api.items_key]: url = album["tracks"]["href"].split("?")[0] expected = api_mock.calculate_pages_from_response(album) - assert len(api_mock.get_requests(re.compile(url))) == expected + assert len(await api_mock.get_requests(url=url)) == expected assert 
result.tracks expected_features = api_mock.calculate_pages(limit=response_valid[item_key]["limit"], total=len(result.tracks)) - assert len(api_mock.get_requests(re.compile(f"{api.url}/audio-features"))) == expected_features + assert len(await api_mock.get_requests(url=f"{api.url}/audio-features")) == expected_features - def test_load_with_some_items_and_no_extension( + async def test_load_with_some_items_and_no_extension( self, response_valid: dict[str, Any], item_kind: RemoteObjectType, @@ -269,19 +268,19 @@ def test_load_with_some_items_and_no_extension( api: SpotifyAPI, api_mock: SpotifyMock ): - api.extend_items(response_valid, kind=RemoteObjectType.ARTIST, key=item_kind) - api_mock.reset_mock() # reset for new requests checks to work correctly + await api.extend_items(response_valid, kind=RemoteObjectType.ARTIST, key=item_kind) + api_mock.reset() # reset for new requests checks to work correctly assert len(response_valid[item_key][api.items_key]) == response_valid[item_key]["total"] - assert not api_mock.get_requests(response_valid[item_key]["href"]) + assert not await api_mock.get_requests(url=response_valid[item_key]["href"]) - result: SpotifyArtist = SpotifyArtist.load(response_valid, api=api, items=load_items, extend_albums=True) + result: SpotifyArtist = await SpotifyArtist.load(response_valid, api=api, items=load_items, extend_albums=True) - self.assert_load_with_items_requests( + await self.assert_load_with_items_requests( response=response_valid, result=result, items=load_items, key=item_key, api_mock=api_mock ) # requests for extension data - assert not api_mock.get_requests(re.compile(f"{result.url}/{item_key}")) - assert not api_mock.get_requests(re.compile(f"{api.url}/audio-features")) - assert not api_mock.get_requests(re.compile(f"{api.url}/artists")) + assert not await api_mock.get_requests(url=f"{result.url}/{item_key}") + assert not await api_mock.get_requests(url=f"{api.url}/audio-features") + assert not await 
api_mock.get_requests(url=f"{api.url}/artists") diff --git a/tests/libraries/remote/spotify/object/test_playlist.py b/tests/libraries/remote/spotify/object/test_playlist.py index aeef922b..00f923ff 100644 --- a/tests/libraries/remote/spotify/object/test_playlist.py +++ b/tests/libraries/remote/spotify/object/test_playlist.py @@ -1,4 +1,3 @@ -import re from collections.abc import Iterable from copy import deepcopy from datetime import datetime @@ -30,12 +29,6 @@ def collection_merge_items(self, api_mock: SpotifyMock) -> Iterable[SpotifyTrack def item_kind(self, api: SpotifyAPI) -> RemoteObjectType: return api.collection_item_map[RemoteObjectType.PLAYLIST] - @pytest.fixture - def playlist(self, response_valid: dict[str, Any], api: SpotifyAPI) -> SpotifyPlaylist: - pl = SpotifyPlaylist(response=response_valid, api=api) - pl._tracks = [item for item in pl.items if pl.items.count(item) == 1] - return pl - @pytest.fixture def response_random(self, api_mock: SpotifyMock) -> dict[str, Any]: """Yield a randomly generated response from the Spotify API for a track item type""" @@ -45,14 +38,14 @@ def response_random(self, api_mock: SpotifyMock) -> dict[str, Any]: return response @pytest.fixture - def _response_valid(self, api: SpotifyAPI, api_mock: SpotifyMock) -> dict[str, Any]: + async def _response_valid(self, api: SpotifyAPI, api_mock: SpotifyMock) -> dict[str, Any]: response = next( deepcopy(pl) for pl in api_mock.user_playlists if pl["tracks"]["total"] > 50 and len(pl["tracks"]["items"]) > 10 ) - api.extend_items(response=response, key=RemoteObjectType.TRACK) + await api.extend_items(response=response, key=RemoteObjectType.TRACK) - api_mock.reset_mock() # reset for new requests checks to work correctly + api_mock.reset() # reset for new requests checks to work correctly return response @pytest.fixture @@ -63,11 +56,17 @@ def response_valid(self, _response_valid: dict[str, Any]) -> dict[str, Any]: """ return deepcopy(_response_valid) - def test_input_validation(self, 
response_random: dict[str, Any], api_mock: SpotifyMock): + @pytest.fixture + def playlist(self, response_valid: dict[str, Any], api: SpotifyAPI) -> SpotifyPlaylist: + pl = SpotifyPlaylist(response=response_valid, api=api) + pl._tracks = [item for item in pl.items if pl.items.count(item) == 1] + return pl + + async def test_input_validation(self, response_random: dict[str, Any], api_mock: SpotifyMock): with pytest.raises(RemoteObjectTypeError): SpotifyPlaylist(api_mock.generate_album(track_count=0)) with pytest.raises(APIError): - SpotifyPlaylist(response_random).reload() + await SpotifyPlaylist(response_random).reload() response_random["tracks"]["total"] += 10 with pytest.raises(RemoteError): @@ -81,11 +80,11 @@ def test_input_validation(self, response_random: dict[str, Any], api_mock: Spoti # no API set, these will not run with pytest.raises(APIError): - pl.reload() + await pl.reload() with pytest.raises(APIError): - pl.delete() + await pl.delete() with pytest.raises(RemoteError): - pl.sync() + await pl.sync() def test_writeable(self, response_valid: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): pl = SpotifyPlaylist(response_valid) @@ -177,7 +176,7 @@ def test_refresh(self, response_valid: dict[str, Any]): pl.refresh(skip_checks=True) assert len(pl.tracks) == original_track_count // 2 - def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): + async def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): response_valid["description"] = None response_valid["public"] = not response_valid["public"] response_valid["collaborative"] = not response_valid["collaborative"] @@ -188,7 +187,7 @@ def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): assert pl.collaborative is response_valid["collaborative"] pl.api = api - pl.reload(extend_artists=True) + await pl.reload(extend_artists=True) assert pl.description assert pl.public is not response_valid["public"] assert pl.collaborative is not 
response_valid["collaborative"] @@ -197,13 +196,13 @@ def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): ## Load method tests ########################################################################### @staticmethod - def get_load_without_items( + async def get_load_without_items( loader: SpotifyPlaylist, response_valid: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock ): - return loader.load(response_valid["href"], api=api, extend_tracks=True) + return await loader.load(response_valid["href"], api=api, extend_tracks=True) @pytest.fixture def load_items( @@ -244,17 +243,17 @@ def load_items( return items - def test_load_with_all_items( + async def test_load_with_all_items( self, response_valid: dict[str, Any], item_key: str, api: SpotifyAPI, api_mock: SpotifyMock ): load_items = [SpotifyTrack(response) for response in response_valid[item_key][api.items_key]] - SpotifyPlaylist.load( + await SpotifyPlaylist.load( response_valid, api=api, items=load_items, extend_albums=True, extend_tracks=False, extend_features=False ) - assert not api_mock.request_history + api_mock.assert_not_called() - def test_load_with_some_items( + async def test_load_with_some_items( self, response_valid: dict[str, Any], item_key: str, @@ -264,24 +263,24 @@ def test_load_with_some_items( ): kind = RemoteObjectType.PLAYLIST - result: SpotifyPlaylist = SpotifyPlaylist.load( + result: SpotifyPlaylist = await SpotifyPlaylist.load( response_valid, api=api, items=load_items, extend_tracks=True, extend_features=True ) - self.assert_load_with_items_requests( + await self.assert_load_with_items_requests( response=response_valid, result=result, items=load_items, key=item_key, api_mock=api_mock ) - self.assert_load_with_items_extended( + await self.assert_load_with_items_extended( response=response_valid, result=result, items=load_items, kind=kind, key=item_key, api_mock=api_mock ) # requests for extension data expected = api_mock.calculate_pages_from_response(response_valid) # 
-1 for not calling initial page - assert len(api_mock.get_requests(re.compile(f"{result.url}/{item_key}"))) == expected - 1 - assert len(api_mock.get_requests(re.compile(f"{api.url}/audio-features"))) == expected + assert len(await api_mock.get_requests(url=f"{result.url}/{item_key}")) == expected - 1 + assert len(await api_mock.get_requests(url=f"{api.url}/audio-features")) == expected - def test_load_with_some_items_and_no_extension( + async def test_load_with_some_items_and_no_extension( self, response_valid: dict[str, Any], item_kind: RemoteObjectType, @@ -290,51 +289,52 @@ def test_load_with_some_items_and_no_extension( api: SpotifyAPI, api_mock: SpotifyMock ): - api.extend_items(response_valid, kind=RemoteObjectType.PLAYLIST, key=item_kind) - api_mock.reset_mock() # reset for new requests checks to work correctly + await api.extend_items(response_valid, kind=RemoteObjectType.PLAYLIST, key=item_kind) + api_mock.reset() # reset for new requests checks to work correctly assert len(response_valid[item_key][api.items_key]) == response_valid[item_key]["total"] - assert not api_mock.get_requests(response_valid[item_key]["href"]) + assert not await api_mock.get_requests(url=response_valid[item_key]["href"]) - result: SpotifyPlaylist = SpotifyPlaylist.load( + result: SpotifyPlaylist = await SpotifyPlaylist.load( response_valid, api=api, items=load_items, extend_tracks=True, extend_features=False ) - self.assert_load_with_items_requests( + await self.assert_load_with_items_requests( response=response_valid, result=result, items=load_items, key=item_key, api_mock=api_mock ) - assert not api_mock.get_requests(response_valid[item_key]["href"]) + assert not await api_mock.get_requests(url=response_valid[item_key]["href"]) # requests for extension data - assert not api_mock.get_requests(re.compile(f"{result.url}/{item_key}")) # already extended on input - assert not api_mock.get_requests(re.compile(f"{api.url}/audio-features")) + assert not await 
api_mock.get_requests(url=f"{result.url}/{item_key}") # already extended on input + assert not await api_mock.get_requests(url=f"{api.url}/audio-features") - def test_create_playlist(self, api: SpotifyAPI, api_mock: SpotifyMock): + async def test_create_playlist(self, api: SpotifyAPI, api_mock: SpotifyMock): name = "new playlist" - pl = SpotifyPlaylist.create(api=api, name="new playlist", public=False, collaborative=True) + pl = await SpotifyPlaylist.create(api=api, name="new playlist", public=False, collaborative=True) url = f"{api.url}/users/{api_mock.user_id}/playlists" - body = api_mock.get_requests(url=url, response={"name": name})[0].json() + _, request, _ = next(iter(await api_mock.get_requests(url=url, response={"name": name}))) + payload = self._get_payload_from_request(request) - assert body["name"] == name - assert PROGRAM_NAME in body["description"] - assert not body["public"] - assert body["collaborative"] + assert payload["name"] == name + assert PROGRAM_NAME in payload["description"] + assert not payload["public"] + assert payload["collaborative"] assert pl.name == name assert PROGRAM_NAME in pl.description assert not pl.public assert pl.collaborative - def test_delete_playlist(self, response_valid: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): + async def test_delete_playlist(self, response_valid: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock): names = [pl["name"] for pl in api_mock.user_playlists] response = next(deepcopy(pl) for pl in api_mock.user_playlists if names.count(pl["name"]) == 1) - api.extend_items(response=response, key=RemoteObjectType.TRACK) + await api.extend_items(response=response, key=RemoteObjectType.TRACK) pl = SpotifyPlaylist(response=response, api=api) url = pl.url - pl.delete() - assert api_mock.get_requests(url=url + "/followers") + await pl.delete() + assert await api_mock.get_requests(url=url + "/followers") assert not pl.response 
########################################################################### @@ -349,8 +349,7 @@ def sync_playlist(self, response_valid: dict[str, Any], api: SpotifyAPI) -> Spot def sync_items( response_valid: dict[str, Any], response_random: dict[str, Any], api: SpotifyAPI, api_mock: SpotifyMock, ) -> list[SpotifyTrack]: - api.load_user_data() - api_mock.reset_mock() # reset for new requests checks to work correctly + api_mock.reset() # reset for new requests checks to work correctly uri_valid = [track["track"]["uri"] for track in response_valid["tracks"]["items"]] return [ @@ -358,20 +357,20 @@ def sync_items( if track["track"]["uri"] not in uri_valid ] - @staticmethod - def get_sync_uris(url: str, api_mock: SpotifyMock) -> tuple[list[str], list[str]]: - requests = api_mock.get_requests(url=f"{url}/tracks") + @classmethod + async def get_sync_uris(cls, url: str, api_mock: SpotifyMock) -> tuple[list[str], list[str]]: + requests = await api_mock.get_requests(url=f"{url}/tracks") uri_add = [] uri_clear = [] - for req in requests: - if not req.body: + for _, request, _ in requests: + payload = cls._get_payload_from_request(request) + if not payload: continue - body = req.json() - if "uris" in body: - uri_add += body["uris"] - elif req.body: - uri_clear += [t["uri"] for t in body["tracks"]] + if "uris" in payload: + uri_add += payload["uris"] + else: + uri_clear += [t["uri"] for t in payload["tracks"]] return uri_add, uri_clear diff --git a/tests/libraries/remote/spotify/object/test_track.py b/tests/libraries/remote/spotify/object/test_track.py index 2ff1d3b8..e739bcae 100644 --- a/tests/libraries/remote/spotify/object/test_track.py +++ b/tests/libraries/remote/spotify/object/test_track.py @@ -47,11 +47,11 @@ def response_valid(self, api_mock: SpotifyMock) -> dict[str, Any]: """ return deepcopy(next(track for track in api_mock.tracks)) - def test_input_validation(self, response_random: dict[str, Any], api_mock: SpotifyMock): + async def test_input_validation(self, 
response_random: dict[str, Any], api_mock: SpotifyMock): with pytest.raises(RemoteObjectTypeError): SpotifyTrack(api_mock.generate_artist(properties=False)) with pytest.raises(APIError): - SpotifyTrack(response_random).reload() + await SpotifyTrack(response_random).reload() def test_attributes(self, response_random: dict[str, Any]): track = SpotifyTrack(response_random) @@ -183,7 +183,7 @@ def test_refresh(self, response_valid: dict[str, Any]): track.refresh(skip_checks=True) assert len(track.artists) == 1 - def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): + async def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): response_valid["album"].pop("name", None) response_valid.pop("audio_features", None) @@ -193,7 +193,7 @@ def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): assert not track.bpm track.api = api - track.reload(features=True) + await track.reload(features=True) assert track.album if track.response["audio_features"]["key"] > -1: assert track.key @@ -201,8 +201,8 @@ def test_reload(self, response_valid: dict[str, Any], api: SpotifyAPI): assert track.key is None assert track.bpm - def test_load(self, response_valid: dict[str, Any], api: SpotifyAPI): - track = SpotifyTrack.load(response_valid["href"], api=api) + async def test_load(self, response_valid: dict[str, Any], api: SpotifyAPI): + track = await SpotifyTrack.load(response_valid["href"], api=api) assert track.name == response_valid["name"] assert track.id == response_valid["id"] diff --git a/tests/libraries/remote/spotify/object/testers.py b/tests/libraries/remote/spotify/object/testers.py index 15dca298..35a7651c 100644 --- a/tests/libraries/remote/spotify/object/testers.py +++ b/tests/libraries/remote/spotify/object/testers.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from collections.abc import Iterable from typing import Any -from urllib.parse import parse_qs +from urllib.parse import unquote import pytest @@ -33,7 +33,7 @@ 
def item_key(self, item_kind: RemoteObjectType) -> str: ## Assertions ########################################################################### @staticmethod - def assert_load_with_items_requests[T: SpotifyObject]( + async def assert_load_with_items_requests[T: SpotifyObject]( response: dict[str, Any], result: SpotifyCollectionLoader[T], items: list[T], @@ -43,19 +43,18 @@ def assert_load_with_items_requests[T: SpotifyObject]( """Run assertions on the requests from load method with given ``items``""" assert len(result.response[key][result.api.items_key]) == response[key]["total"] assert len(result.items) == response[key]["total"] - assert not api_mock.get_requests(result.url) # main collection URL was not called + assert not await api_mock.get_requests(url=result.url) # main collection URL was not called # ensure none of the input_ids were requested input_ids = {item.id for item in items} - for request in api_mock.get_requests(f"{result.url}/{key}"): - params = parse_qs(request.query) - if "ids" not in params: + for url, _, _ in await api_mock.get_requests(url=f"{result.url}/{key}"): + if "ids" not in url.query: continue - assert not input_ids.intersection(params["ids"][0].split(",")) + assert not input_ids.intersection(unquote(url.query["ids"]).split(",")) @staticmethod - def assert_load_with_items_extended[T: SpotifyObject]( + async def assert_load_with_items_extended[T: SpotifyObject]( response: dict[str, Any], result: SpotifyCollectionLoader[T], items: list[T], @@ -64,7 +63,7 @@ def assert_load_with_items_extended[T: SpotifyObject]( api_mock: SpotifyMock, ): """Run assertions on the requests for missing data from load method with given ``items``""" - requests_missing = api_mock.get_requests(f"{result.api.url}/{key}") + requests_missing = await api_mock.get_requests(url=f"{result.api.url}/{key}") limit = response[key]["limit"] input_ids = {item.id for item in items} response_item_ids = { @@ -78,7 +77,7 @@ def assert_load_with_items_extended[T: SpotifyObject]( 
########################################################################### @staticmethod @abstractmethod - def get_load_without_items( + async def get_load_without_items( loader: SpotifyCollectionLoader, response_valid: dict[str, Any], api: SpotifyAPI, @@ -87,7 +86,7 @@ def get_load_without_items( """Yields the results from 'load' where no items are given as a pytest.fixture.""" raise NotImplementedError - def test_load_without_items( + async def test_load_without_items( self, collection: SpotifyCollectionLoader, response_valid: dict[str, Any], @@ -95,7 +94,7 @@ def test_load_without_items( api: SpotifyAPI, api_mock: SpotifyMock ): - result = self.get_load_without_items( + result = await self.get_load_without_items( loader=collection, response_valid=response_valid, api=api, api_mock=api_mock ) @@ -107,13 +106,13 @@ def test_load_without_items( if not isinstance(result, SpotifyArtist): expected -= 1 # -1 for not calling initial page - assert len(api_mock.get_requests(result.url)) == 1 - assert len(api_mock.get_requests(f"{result.url}/{item_key}")) == expected - assert not api_mock.get_requests(f"{api.url}/audio-features") - assert not api_mock.get_requests(f"{api.url}/audio-analysis") + assert len(await api_mock.get_requests(url=result.url)) == 1 + assert len(await api_mock.get_requests(url=f"{result.url}/{item_key}")) == expected + assert not await api_mock.get_requests(url=f"{api.url}/audio-features") + assert not await api_mock.get_requests(url=f"{api.url}/audio-analysis") # input items given, but no key to search on still loads - result = collection.load(response_valid, api=api, items=response_valid.pop(item_key), extend_tracks=True) + result = await collection.load(response_valid, api=api, items=response_valid.pop(item_key), extend_tracks=True) assert result.name == response_valid["name"] assert result.id == response_valid["id"] diff --git a/tests/libraries/remote/spotify/test_library.py b/tests/libraries/remote/spotify/test_library.py index e10cdbae..0c90409c 
100644 --- a/tests/libraries/remote/spotify/test_library.py +++ b/tests/libraries/remote/spotify/test_library.py @@ -1,6 +1,6 @@ from copy import deepcopy from random import sample -from urllib.parse import parse_qs +from urllib.parse import unquote import pytest @@ -28,21 +28,21 @@ def library_unloaded(self, api: SpotifyAPI, api_mock: SpotifyMock) -> SpotifyLib return SpotifyLibrary(api=api, playlist_filter=include) @pytest.fixture(scope="class") - def _library(self, api: SpotifyAPI, _api_mock: SpotifyMock) -> SpotifyLibrary: + async def _library(self, api: SpotifyAPI, _api_mock: SpotifyMock) -> SpotifyLibrary: include = FilterDefinedList([pl["name"] for pl in sample(_api_mock.user_playlists, k=10)]) library = SpotifyLibrary(api=api, playlist_filter=include) - library.load() + await library.load() return library @pytest.fixture def library(self, _library: SpotifyLibrary) -> SpotifyLibrary: return deepcopy(_library) - def test_filter_playlists(self, api: SpotifyAPI, api_mock: SpotifyMock): + async def test_filter_playlists(self, api: SpotifyAPI, api_mock: SpotifyMock): # keep all when no include or exclude settings defined library = SpotifyLibrary(api=api) - responses = api.get_user_items(kind=RemoteObjectType.PLAYLIST) + responses = await api.get_user_items(kind=RemoteObjectType.PLAYLIST) filtered = library._filter_playlists(responses) assert len(filtered) == len(api_mock.user_playlists) == len(responses) @@ -57,28 +57,28 @@ def test_filter_playlists(self, api: SpotifyAPI, api_mock: SpotifyMock): ########################################################################### ## Load tests ########################################################################### - def test_load_tracks(self, library_unloaded: SpotifyLibrary, api_mock: SpotifyMock): - library_unloaded.load_tracks() + async def test_load_tracks(self, library_unloaded: SpotifyLibrary, api_mock: SpotifyMock): + await library_unloaded.load_tracks() assert len(library_unloaded.tracks) == 
len(api_mock.user_tracks) # does not add duplicates to the loaded list - library_unloaded.load_tracks() + await library_unloaded.load_tracks() assert len(library_unloaded.tracks) == len(api_mock.user_tracks) - def test_load_saved_albums(self, library_unloaded: SpotifyLibrary, api_mock: SpotifyMock): - library_unloaded.load_saved_albums() + async def test_load_saved_albums(self, library_unloaded: SpotifyLibrary, api_mock: SpotifyMock): + await library_unloaded.load_saved_albums() assert len(library_unloaded.albums) == len(api_mock.user_albums) # does not add duplicates to the loaded list - library_unloaded.load_saved_albums() + await library_unloaded.load_saved_albums() assert len(library_unloaded.albums) == len(api_mock.user_albums) - def test_load_saved_artists(self, library_unloaded: SpotifyLibrary, api_mock: SpotifyMock): - library_unloaded.load_saved_artists() + async def test_load_saved_artists(self, library_unloaded: SpotifyLibrary, api_mock: SpotifyMock): + await library_unloaded.load_saved_artists() assert len(library_unloaded.artists) == len(api_mock.user_artists) # does not add duplicates to the loaded list - library_unloaded.load_saved_artists() + await library_unloaded.load_saved_artists() assert len(library_unloaded.artists) == len(api_mock.user_artists) ########################################################################### @@ -86,7 +86,7 @@ def test_load_saved_artists(self, library_unloaded: SpotifyLibrary, api_mock: Sp ########################################################################### # noinspection PyMethodOverriding,PyTestUnpassedFixture @pytest.mark.slow - def test_enrich_tracks(self, library: SpotifyLibrary, api_mock: SpotifyMock, **kwargs): + async def test_enrich_tracks(self, library: SpotifyLibrary, api_mock: SpotifyMock, **kwargs): def validate_track_extras_not_enriched(t: SpotifyTrack) -> None: """Check track does not contain audio features or analysis fields""" assert "audio_features" not in t.response @@ -116,7 +116,7 @@ def 
validate_artists_not_enriched(t: SpotifyTrack) -> None: validate_artists_not_enriched(track) # enriches only albums - library.enrich_tracks(albums=True, artists=False) + await library.enrich_tracks(albums=True, artists=False) for track in library.tracks: assert track.album == track_album_map[track.uri] assert "external_ids" in track.response["album"] @@ -127,18 +127,18 @@ def validate_artists_not_enriched(t: SpotifyTrack) -> None: validate_track_extras_not_enriched(track) validate_artists_not_enriched(track) - assert len(api_mock.get_requests(url=track.response["album"]["href"] + "/tracks")) == 0 + assert len(await api_mock.get_requests(url=track.response["album"]["href"] + "/tracks")) == 0 # check requests - assert len(api_mock.get_requests(url=library.api.url + "/artists")) == 0 - req_albums = api_mock.get_requests(url=library.api.url + "/albums") - req_album_ids = {id_ for req in req_albums for id_ in parse_qs(req.query)["ids"][0].split(",")} + assert len(await api_mock.get_requests(url=library.api.url + "/artists")) == 0 + req_albums = await api_mock.get_requests(url=library.api.url + "/albums") + req_album_ids = {id_ for url, _, _ in req_albums for id_ in unquote(url.query["ids"]).split(",")} assert req_album_ids == album_ids - api_mock.reset_mock() # reset for new requests checks to work correctly + api_mock.reset() # reset for new requests checks to work correctly # enriches artists without replacing previous enrichment - library.enrich_tracks(albums=False, artists=True) + await library.enrich_tracks(albums=False, artists=True) for track in library.tracks: assert track.album == track_album_map[track.uri] assert "external_ids" in track.response["album"] @@ -146,7 +146,7 @@ def validate_artists_not_enriched(t: SpotifyTrack) -> None: assert "popularity" in track.response["album"] assert "tracks" not in track.response["album"] - assert len(api_mock.get_requests(url=track.response["album"]["href"] + "/tracks")) == 0 + assert len(await 
api_mock.get_requests(url=track.response["album"]["href"] + "/tracks")) == 0 assert [a.name for a in track.artists] == track_artists_map[track.uri] for artist in track.response["artists"]: @@ -157,38 +157,38 @@ def validate_artists_not_enriched(t: SpotifyTrack) -> None: validate_track_extras_not_enriched(track) # check requests - assert len(api_mock.get_requests(url=library.api.url + "/albums")) == 0 - req_artists = api_mock.get_requests(url=library.api.url + "/artists") - req_artist_ids = {id_ for req in req_artists for id_ in parse_qs(req.query)["ids"][0].split(",")} + assert len(await api_mock.get_requests(url=library.api.url + "/albums")) == 0 + req_artists = await api_mock.get_requests(url=library.api.url + "/artists") + req_artist_ids = {id_ for url, _, _ in req_artists for id_ in unquote(url.query["ids"]).split(",")} assert req_artist_ids == artist_ids # just check these fields were now added - library.enrich_tracks(features=True, analysis=True) + await library.enrich_tracks(features=True, analysis=True) for track in library.tracks: assert "audio_features" in track.response assert "audio_analysis" in track.response @pytest.mark.slow - def test_enrich_saved_albums(self, library: SpotifyLibrary, **kwargs): + async def test_enrich_saved_albums(self, library: SpotifyLibrary, **kwargs): # ensure at least some albums are not enriched already assert any(len(album.response["tracks"]["items"]) != album.track_total for album in library.albums) assert any(len(album.tracks) != album.track_total for album in library.albums) - library.enrich_saved_albums() + await library.enrich_saved_albums() for album in library.albums: assert len(album.response["tracks"]["items"]) == album.track_total assert len(album.tracks) == album.track_total # noinspection PyMethodOverriding @pytest.mark.slow - def test_enrich_saved_artists(self, library: SpotifyLibrary, api_mock: SpotifyMock, **kwargs): + async def test_enrich_saved_artists(self, library: SpotifyLibrary, api_mock: SpotifyMock, 
**kwargs): # ensure artists are not enriched already for artist in library.artists: assert "albums" not in artist.response assert len(artist.albums) == 0 # gets albums but does not extend them - library.enrich_saved_artists(tracks=False) + await library.enrich_saved_artists(tracks=False) assert any(len(artist.response["albums"]["items"]) > 0 for artist in library.artists) for artist in library.artists: assert "albums" in artist.response @@ -200,15 +200,15 @@ def test_enrich_saved_artists(self, library: SpotifyLibrary, api_mock: SpotifyMo assert len(album.tracks) == len(album.response["tracks"].get("items", [])) == 0 # only album URLs were called - req_urls = [req.url.split("?")[0] for req in api_mock.request_history] + req_urls = [str(url.with_query(None)) for url, _, _ in await api_mock.get_requests()] assert req_urls == [artist.url + "/albums" for artist in library.artists] - library.enrich_saved_artists(tracks=True) + await library.enrich_saved_artists(tracks=True) for artist in library.artists: for response in artist.response["albums"]["items"]: assert len(response["tracks"].get("items", [])) == response["total_tracks"] > 0 for album in artist.albums: assert len(album.tracks) == len(album.response["tracks"].get("items", [])) == album.track_total > 0 - req_urls = set(req.url.split("?")[0] for req in api_mock.request_history) + req_urls = set(str(url.with_query(None)) for url, _, _ in await api_mock.get_requests()) assert all(album.url + "/tracks" in req_urls for artist in library.artists for album in artist.albums)