Skip to content

Commit

Permalink
Improve oauth layer (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
inverse authored Feb 19, 2022
1 parent 7bd7732 commit 8b4a9d7
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
67 changes: 65 additions & 2 deletions iolite_client/oauth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,23 @@ async def get_sid(self, access_token: str) -> str:
return response_json.get("SID")


class OAuthStorage:
class AsyncOAuthStorageInterface:
async def store_access_token(self, payload: dict):
raise NotImplementedError

async def fetch_access_token(self) -> Optional[dict]:
raise NotImplementedError


class OAuthStorageInterface:
def store_access_token(self, payload: dict):
raise NotImplementedError

def fetch_access_token(self) -> Optional[dict]:
raise NotImplementedError


class LocalOAuthStorage(OAuthStorageInterface):
def __init__(self, path: str):
self.path = path

Expand Down Expand Up @@ -178,7 +194,9 @@ def __get_path(self, payload_type: str):


class OAuthWrapper:
def __init__(self, oauth_handler: OAuthHandler, oauth_storage: OAuthStorage):
def __init__(
self, oauth_handler: OAuthHandler, oauth_storage: OAuthStorageInterface
):
self.oauth_handler = oauth_handler
self.oauth_storage = oauth_storage

Expand Down Expand Up @@ -218,3 +236,48 @@ def _refresh_access_token(self, access_token):
)
self.oauth_storage.store_access_token(refreshed_token)
return refreshed_token["access_token"]


class AsyncOAuthWrapper:
def __init__(
self,
oauth_handler: AsyncOAuthHandler,
oauth_storage: AsyncOAuthStorageInterface,
):
self.oauth_handler = oauth_handler
self.oauth_storage = oauth_storage

async def get_sid(self, code: str, name: str) -> str:
access_token = await self.oauth_storage.fetch_access_token()

if access_token is None:
logger.debug("No token, requesting")
access_token = await self.oauth_handler.get_access_token(code, name)
await self.oauth_storage.store_access_token(access_token)

if access_token["expires_at"] < time.time():
logger.debug("Token expired, refreshing")
token = await self._refresh_token(access_token)
else:
token = access_token["access_token"]

logger.debug("Fetched access token")

try:
return await self.oauth_handler.get_sid(token)
except BaseException as e:
logger.debug(f"Invalid token, attempt refresh: {e}")
token = await self._refresh_token(access_token)
return await self.oauth_handler.get_sid(token)

async def _refresh_token(self, access_token: dict) -> str:
"""Refresh token."""
refreshed_token = await self.oauth_handler.get_new_access_token(
access_token["refresh_token"]
)
expires_at = time.time() + refreshed_token["expires_in"]
refreshed_token.update({"expires_at": expires_at})
del refreshed_token["expires_in"]
await self.oauth_storage.store_access_token(refreshed_token)

return refreshed_token["access_token"]
4 changes: 2 additions & 2 deletions scripts/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from iolite_client.client import Client
from iolite_client.entity import RadiatorValve
from iolite_client.oauth_handler import OAuthHandler, OAuthStorage, OAuthWrapper
from iolite_client.oauth_handler import LocalOAuthStorage, OAuthHandler, OAuthWrapper

env = Env()
env.read_env()
Expand All @@ -20,7 +20,7 @@
logger = logging.getLogger(__name__)

# Get SID
oauth_storage = OAuthStorage(".")
oauth_storage = LocalOAuthStorage(".")
oauth_handler = OAuthHandler(USERNAME, PASSWORD)
oauth_wrapper = OAuthWrapper(oauth_handler, oauth_storage)
sid = oauth_wrapper.get_sid(CODE, NAME)
Expand Down

0 comments on commit 8b4a9d7

Please sign in to comment.