Skip to content

Commit

Permalink
add back __call__ methods on various awaitable objects
Browse files Browse the repository at this point in the history
  • Loading branch information
geo-martino committed May 24, 2024
1 parent 1ea89be commit 7de6055
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 15 deletions.
4 changes: 2 additions & 2 deletions docs/_howto/scripts/sync/p1.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ async def match_albums_to_remote(albums: Collection[MusifyCollection], factory:

searcher = RemoteItemSearcher(matcher=matcher, object_factory=factory)
async with searcher:
await searcher.search(albums)
await searcher(albums)

checker = RemoteItemChecker(matcher=matcher, object_factory=factory)
async with checker:
await checker.check(albums)
await checker(albums)
4 changes: 0 additions & 4 deletions docs/release-history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ Removed
* Dependency on ``requests-cache`` package in favour of custom cache implementation.
* ``use_cache`` parameter from all :py:class:`.RemoteAPI` related methods.
Cache settings now handled by :py:class:`.ResponseCache`
* Ability to call the following classes directly due to their need to now by asynchronous:
* :py:class:`.APIAuthoriser`
* :py:class:`.RemoteItemChecker`
* :py:class:`.RemoteItemSearcher`
* ThreadPoolExecutor use on :py:class:`.RemoteItemSearcher`. Now uses asynchronous logic instead.

Documentation
Expand Down
5 changes: 4 additions & 1 deletion musify/api/authorise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import os
import socket
from collections.abc import Callable, Mapping, Sequence, MutableMapping
from collections.abc import Callable, Mapping, Sequence, MutableMapping, Awaitable
from datetime import datetime
from typing import Any
from urllib.parse import unquote
Expand Down Expand Up @@ -204,6 +204,9 @@ def save_token(self) -> None:
with open(self.token_file_path, "w") as file:
json.dump(self.token, file, indent=2)

def __call__(self, force_load: bool = False, force_new: bool = False) -> Awaitable[dict[str, str]]:
return self.authorise(force_load=force_load, force_new=force_new)

async def authorise(self, force_load: bool = False, force_new: bool = False) -> dict[str, str]:
"""
Main method for authorisation which tests/refreshes/reauthorises as needed.
Expand Down
2 changes: 1 addition & 1 deletion musify/api/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def authorise(self, force_load: bool = False, force_new: bool = False) ->

headers = {}
if self.authoriser is not None:
headers = await self.authoriser.authorise(force_load=force_load, force_new=force_new)
headers = await self.authoriser(force_load=force_load, force_new=force_new)
self.session.headers.update(headers)

return headers
Expand Down
7 changes: 6 additions & 1 deletion musify/libraries/remote/core/processors/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
import logging
from collections import Counter
from collections.abc import Sequence, Collection, Iterator
from collections.abc import Sequence, Collection, Iterator, Awaitable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Self
Expand Down Expand Up @@ -173,6 +173,11 @@ async def _delete_playlists(self) -> None:
self._playlist_name_urls.clear()
self._playlist_name_collection.clear()

def __call__[T: MusifyItemSettable](
self, collections: Collection[MusifyCollection[T]]
) -> Awaitable[ItemCheckResult[T] | None]:
return self.check(collections)

async def check[T: MusifyItemSettable](
self, collections: Collection[MusifyCollection[T]]
) -> ItemCheckResult[T] | None:
Expand Down
7 changes: 6 additions & 1 deletion musify/libraries/remote/core/processors/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
and assigns the ID of the matched object back to the item.
"""
import logging
from collections.abc import Mapping, Sequence, Iterable, Collection
from collections.abc import Mapping, Sequence, Iterable, Collection, Awaitable
from dataclasses import dataclass, field
from typing import Any, Self

Expand Down Expand Up @@ -193,6 +193,11 @@ def _determine_remote_object_type(obj: MusifyObject) -> RemoteObjectType:
return obj.kind
raise MusifyAttributeError(f"Given object does not specify a RemoteObjectType: {obj.__class__.__name__}")

def __call__[T: MusifyItemSettable](
self, collections: Collection[MusifyCollection[T]]
) -> Awaitable[dict[str, ItemSearchResult[T]]]:
return self.search(collections)

async def search[T: MusifyItemSettable](
self, collections: Collection[MusifyCollection[T]]
) -> dict[str, ItemSearchResult[T]]:
Expand Down
6 changes: 3 additions & 3 deletions tests/api/test_authorise.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ async def test_auth_new_token(self, token: dict[str, Any], token_file_path: str,
response = {"access_token": "valid token", "expires_in": 3000, "refresh_token": "new_refresh"}
requests_mock.post(authoriser.auth_args["url"], payload=response)

await authoriser.authorise()
await authoriser()
expected_header = {"Authorization": "Bearer valid token"}
assert authoriser.headers == expected_header
assert authoriser.token["refresh_token"] == "new_refresh"
Expand All @@ -235,7 +235,7 @@ async def test_auth_load_and_token_valid(self, token_file_path: str, requests_mo
requests_mock.get(authoriser.test_args["url"], payload={"test": "valid"})

# loads token, token is valid, no refresh needed
await authoriser.authorise()
await authoriser()
expected_header = {"Authorization": f"Bearer {authoriser.token["access_token"]}"}
assert authoriser.headers == expected_header

Expand Down Expand Up @@ -273,7 +273,7 @@ async def test_auth_new_token_and_no_refresh(

requests_mock.post(authoriser.auth_args["url"], payload={"1": {"2": {"code": "token"}}})

await authoriser.authorise()
await authoriser()
expected_header = {"Authorization": "Bearer token"}
assert authoriser.headers == expected_header

Expand Down
2 changes: 1 addition & 1 deletion tests/libraries/remote/core/processors/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ async def add_collection(self, collection: BasicCollection):
values = ["", *["ua" for _ in batch_1], "s", *["ua" for _ in batch_2]]
patch_input(values, mocker=mocker)

result = await checker.check(collections)
result = await checker(collections)
mocker.stopall()

assert count == len(batch_1) + len(batch_2) # only 2 batches executed
Expand Down
2 changes: 1 addition & 1 deletion tests/libraries/remote/core/processors/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ async def test_search(
search_collection = BasicCollection(name="test", items=search_items + unmatchable_items)
skip_album = len([item for item in search_album if item.has_uri is not None])

results = await searcher.search([search_collection, search_album])
results = await searcher([search_collection, search_album])

result = results[search_collection.name]
assert len(result.matched) + len(result.unmatched) + len(result.skipped) == len(search_collection)
Expand Down

0 comments on commit 7de6055

Please sign in to comment.