Skip to content

Commit

Permalink
Refactor token logic to internal methods to ensure stable API
Browse files Browse the repository at this point in the history
  • Loading branch information
NeonDaniel committed Jan 24, 2024
1 parent 0ecfadf commit 683bcf2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
12 changes: 6 additions & 6 deletions neon_utils/hana_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ def _init_client(backend_address: str):
with open(_client_config_path) as f:
_client_config = json.load(f)
else:
get_token(backend_address)
_get_token(backend_address)

if not _headers:
_headers = {"Authorization": f"Bearer {_client_config['access_token']}"}


def get_token(backend_address: str, username: str = "guest",
password: str = "password"):
def _get_token(backend_address: str, username: str = "guest",
password: str = "password"):
"""
Get new auth tokens from the specified server. This will cache the returned
token, overwriting any previous data at the cache path.
Expand All @@ -87,7 +87,7 @@ def get_token(backend_address: str, username: str = "guest",
json.dump(_client_config, f, indent=2)


def refresh_token(backend_address: str):
def _refresh_token(backend_address: str):
"""
Get new tokens from the specified server using an existing refresh token
(if it exists). This will update the cached tokens and associated metadata.
Expand Down Expand Up @@ -127,10 +127,10 @@ def request_backend(endpoint: str, request_data: dict,
_init_client(backend_address)
if time() >= _client_config.get("expiration", 0):
try:
refresh_token(backend_address)
_refresh_token(backend_address)
except ServerException as e:
LOG.error(e)
get_token(backend_address)
_get_token(backend_address)
resp = requests.post(f"{backend_address}/{endpoint.lstrip('/')}",
json=request_data, headers=_headers)
if resp.ok:
Expand Down
10 changes: 5 additions & 5 deletions tests/hana_util_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def test_request_backend(self):
# TODO: Test invalid route, invalid request data

def test_00_get_token(self):
from neon_utils.hana_utils import get_token
from neon_utils.hana_utils import _get_token

# Test valid request
get_token(self.test_server)
_get_token(self.test_server)
from neon_utils.hana_utils import _client_config
self.assertTrue(isfile(self.test_path))
with open(self.test_path) as f:
Expand All @@ -93,13 +93,13 @@ def _write_token(*_, **__):
json.dump(valid_config, c)
neon_utils.hana_utils._client_config = valid_config

from neon_utils.hana_utils import refresh_token
from neon_utils.hana_utils import _refresh_token
get_token.side_effect = _write_token

self.assertFalse(isfile(self.test_path))

# Test valid request (auth + refresh)
refresh_token(self.test_server)
_refresh_token(self.test_server)
get_token.assert_called_once()
from neon_utils.hana_utils import _client_config
self.assertTrue(isfile(self.test_path))
Expand All @@ -108,7 +108,7 @@ def _write_token(*_, **__):
self.assertEqual(credentials_on_disk, _client_config)

# Test refresh of existing token (no auth)
refresh_token(self.test_server)
_refresh_token(self.test_server)
get_token.assert_called_once()
with open(self.test_path) as f:
new_credentials = json.load(f)
Expand Down

0 comments on commit 683bcf2

Please sign in to comment.