Skip to content

Commit

Permalink
Extract logic to helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
dkirov-dd committed Jan 22, 2025
1 parent d69443d commit 3614393
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
27 changes: 23 additions & 4 deletions datadog_checks_base/datadog_checks/base/utils/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
# https://tools.ietf.org/html/rfc2988
DEFAULT_TIMEOUT = 10

DEFAULT_EXPIRATION = 300

# 16 KiB seems optimal, and is also the standard chunk size of the Bittorrent protocol:
# https://www.bittorrent.org/beps/bep_0003.html
DEFAULT_CHUNK_SIZE = 16
Expand Down Expand Up @@ -859,11 +861,13 @@ def read(self, **request):
self._expiration = get_timestamp()
try:
# According to https://www.rfc-editor.org/rfc/rfc6749#section-5.1, the `expires_in` field is optional
token_expiration = response.get('expires_in', 0)
self._expiration += token_expiration
self._expiration += _parse_expires_in(response.get('expires_in'))
except TypeError:
LOGGER.warning('OAuth2 included an `expires_in` value of unexpected type %s.', type(token_expiration))

LOGGER.warning(
'The `expires_in` field of the OAuth2 response is not a number, defaulting to %s',
DEFAULT_EXPIRATION,
)
self._expiration += DEFAULT_EXPIRATION
return self._token


Expand Down Expand Up @@ -997,6 +1001,21 @@ def quote_uds_url(url):
return urlunparse(parsed)


def _parse_expires_in(token_expiration):
if isinstance(token_expiration, int) or isinstance(token_expiration, float):
return token_expiration
if isinstance(token_expiration, str):
try:
token_expiration = int(token_expiration)
except ValueError:
LOGGER.warning('Could not convert %s to an integer', token_expiration)
else:
LOGGER.warning('Unexpected type for `expires_in`: %s.', type(token_expiration))
token_expiration = None

return token_expiration


# For documentation generation
# TODO: use an enum and remove STANDARD_FIELDS when mkdocstrings supports it
class StandardFields(object):
Expand Down
19 changes: 12 additions & 7 deletions datadog_checks_base/tests/base/utils/http/test_authtoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest

from datadog_checks.base import ConfigurationError
from datadog_checks.base.utils.http import RequestsWrapper
from datadog_checks.base.utils.http import DEFAULT_EXPIRATION, RequestsWrapper
from datadog_checks.base.utils.time import get_timestamp
from datadog_checks.dev import TempDir
from datadog_checks.dev.fs import read_file, write_file
Expand Down Expand Up @@ -470,14 +470,18 @@ def fetch_token(self, *args, **kwargs):
http.get('https://www.google.com')

@pytest.mark.parametrize(
'token_response',
'token_response, expected_expiration',
[
pytest.param({'access_token': 'foo', 'expires_in': 9000}, id='With expires_in'),
pytest.param({'access_token': 'foo'}, id='Without expires_in'),
pytest.param({'access_token': 'foo', 'expires_in': 'two minutes'}, id='With string expires_in'),
pytest.param({'access_token': 'foo', 'expires_in': 9000}, 9000, id='With expires_in'),
pytest.param({'access_token': 'foo'}, DEFAULT_EXPIRATION, id='Without expires_in'),
pytest.param(
{'access_token': 'foo', 'expires_in': 'two minutes'}, DEFAULT_EXPIRATION, id='With string expires_in'
),
pytest.param({'access_token': 'foo', 'expires_in': '3600'}, 3600, id='With numeric string expires_in'),
pytest.param({'access_token': 'foo', 'expires_in': [1, 2, 3]}, 300, id='With list expires_in'),
],
)
def test_success(self, token_response):
def test_success(self, token_response, expected_expiration):
instance = {
'auth_token': {
'reader': {'type': 'oauth', 'url': 'foo', 'client_id': 'bar', 'client_secret': 'baz'},
Expand All @@ -499,7 +503,7 @@ def fetch_token(self, *args, **kwargs):

with mock.patch('requests.get') as get, mock.patch('oauthlib.oauth2.BackendApplicationClient'), mock.patch(
'requests_oauthlib.OAuth2Session', side_effect=MockOAuth2Session
):
), mock.patch('datadog_checks.base.utils.http.get_timestamp', return_value=0):
http.get('https://www.google.com')

get.assert_called_with(
Expand All @@ -514,6 +518,7 @@ def fetch_token(self, *args, **kwargs):
)

assert http.options['headers'] == expected_headers
assert http.auth_token_handler.reader._expiration == expected_expiration

def test_success_with_auth_params(self):
instance = {
Expand Down

0 comments on commit 3614393

Please sign in to comment.