diff --git a/ovos_utils/oauth.py b/ovos_utils/oauth.py index 300cd98..5b4ff5b 100644 --- a/ovos_utils/oauth.py +++ b/ovos_utils/oauth.py @@ -1,6 +1,12 @@ +import time + +import requests from json_database import JsonStorageXDG +from oauthlib.oauth2 import WebApplicationClient from ovos_config.locations import get_xdg_cache_save_path +from ovos_utils.log import LOG + class OAuthTokenDatabase(JsonStorageXDG): """ This helper class creates ovos-config-assistant/ovos-backend-manager compatible json databases @@ -68,3 +74,77 @@ def delete_application(self, oauth_service): def total_apps(self): return len(self) + + +def refresh_oauth_token(token_id): + """ + Refresh Oauth token for token_idential token_id. + + Argument: + token_id: development credentials identifier + + Returns: + json string containing token and additional information + """ + # Load all needed data for refresh + with OAuthApplicationDatabase() as db: + app_data = db.get(token_id) + with OAuthTokenDatabase() as db: + token_data = db.get(token_id) + + if (app_data is None or + token_data is None or 'refresh_token' not in token_data): + LOG.warning("Token data doesn't contain a refresh token and " + "cannot be refreshed.") + return + + refresh_token = token_data["refresh_token"] + + # Fall back to token endpoint if no specific refresh endpoint + # has been set + token_endpoint = app_data["token_endpoint"] + + client_id = app_data["client_id"] + client_secret = app_data["client_secret"] + + # Perform refresh + client = WebApplicationClient(client_id, refresh_token=refresh_token) + uri, headers, body = client.prepare_refresh_token_request(token_endpoint) + refresh_result = requests.post(uri, headers=headers, data=body, + auth=(client_id, client_secret)) + + if refresh_result.ok: + new_token_data = refresh_result.json() + # Make sure 'expires_at' entry exists in token + if 'expires_at' not in new_token_data: + new_token_data['expires_at'] = time.time() + token_data['expires_in'] + # Store token + with OAuthTokenDatabase() as db: + token_data.update(new_token_data) + db.update_token(token_id, token_data) + + return token_data + + +def get_oauth_token(token_id, auto_refresh=True): + """ + Get Oauth token for token_id + + Argument: + token_id: development credentials identifier + auto_refresh: refresh expired tokens automatically + + Returns: + json string containing token and additional information + """ + if auto_refresh: + expired = False + with OAuthTokenDatabase() as db: + token_data = db.get(token_id) + if "expires_at" not in token_data: + expired = True + elif token_data["expires_at"] >= time.time(): + expired = True + if expired: + return refresh_oauth_token(token_id) + return OAuthTokenDatabase().get_token(token_id) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index e8bb4de..7cfab88 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -9,4 +9,5 @@ rich-click~=1.7 rich~=13.7 orjson langcodes -timezonefinder \ No newline at end of file +timezonefinder +oauthlib~=3.2 \ No newline at end of file