diff --git a/iolite_client/oauth_handler.py b/iolite_client/oauth_handler.py index 2d43e66..60d410d 100644 --- a/iolite_client/oauth_handler.py +++ b/iolite_client/oauth_handler.py @@ -144,7 +144,23 @@ async def get_sid(self, access_token: str) -> str: return response_json.get("SID") -class OAuthStorage: +class AsyncOAuthStorageInterface: + async def store_access_token(self, payload: dict): + raise NotImplementedError + + async def fetch_access_token(self) -> Optional[dict]: + raise NotImplementedError + + +class OAuthStorageInterface: + def store_access_token(self, payload: dict): + raise NotImplementedError + + def fetch_access_token(self) -> Optional[dict]: + raise NotImplementedError + + +class LocalOAuthStorage(OAuthStorageInterface): def __init__(self, path: str): self.path = path @@ -178,7 +194,9 @@ def __get_path(self, payload_type: str): class OAuthWrapper: - def __init__(self, oauth_handler: OAuthHandler, oauth_storage: OAuthStorage): + def __init__( + self, oauth_handler: OAuthHandler, oauth_storage: OAuthStorageInterface + ): self.oauth_handler = oauth_handler self.oauth_storage = oauth_storage @@ -218,3 +236,48 @@ def _refresh_access_token(self, access_token): ) self.oauth_storage.store_access_token(refreshed_token) return refreshed_token["access_token"] + + +class AsyncOAuthWrapper: + def __init__( + self, + oauth_handler: AsyncOAuthHandler, + oauth_storage: AsyncOAuthStorageInterface, + ): + self.oauth_handler = oauth_handler + self.oauth_storage = oauth_storage + + async def get_sid(self, code: str, name: str) -> str: + access_token = await self.oauth_storage.fetch_access_token() + + if access_token is None: + logger.debug("No token, requesting") + access_token = await self.oauth_handler.get_access_token(code, name) + await self.oauth_storage.store_access_token(access_token) + + if access_token["expires_at"] < time.time(): + logger.debug("Token expired, refreshing") + token = await self._refresh_token(access_token) + else: + token = access_token["access_token"] + + logger.debug("Fetched access token") + + try: + return await self.oauth_handler.get_sid(token) + except BaseException as e: + logger.debug(f"Invalid token, attempt refresh: {e}") + token = await self._refresh_token(access_token) + return await self.oauth_handler.get_sid(token) + + async def _refresh_token(self, access_token: dict) -> str: + """Refresh token.""" + refreshed_token = await self.oauth_handler.get_new_access_token( + access_token["refresh_token"] + ) + expires_at = time.time() + refreshed_token["expires_in"] + refreshed_token.update({"expires_at": expires_at}) + del refreshed_token["expires_in"] + await self.oauth_storage.store_access_token(refreshed_token) + + return refreshed_token["access_token"] diff --git a/scripts/example.py b/scripts/example.py index 17b11de..8882240 100644 --- a/scripts/example.py +++ b/scripts/example.py @@ -5,7 +5,7 @@ from iolite_client.client import Client from iolite_client.entity import RadiatorValve -from iolite_client.oauth_handler import OAuthHandler, OAuthStorage, OAuthWrapper +from iolite_client.oauth_handler import LocalOAuthStorage, OAuthHandler, OAuthWrapper env = Env() env.read_env() @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) # Get SID -oauth_storage = OAuthStorage(".") +oauth_storage = LocalOAuthStorage(".") oauth_handler = OAuthHandler(USERNAME, PASSWORD) oauth_wrapper = OAuthWrapper(oauth_handler, oauth_storage) sid = oauth_wrapper.get_sid(CODE, NAME)