Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add get_access_token method to clients #36

Merged
merged 1 commit into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions msal_requests_auth/auth/base_auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,33 @@ def __call__(
"""
Adds the token to the authorization header.
"""
token = self.get_access_token()
input_request.headers[
"Authorization"
] = f"{token['token_type']} {token['access_token']}"
return input_request

def get_access_token(self) -> Dict[str, str]:
"""
Retrieves the token dictionary from Azure AD.

Returns
-------
dict
"""
token = self._get_access_token()
if "access_token" not in token:
error = token.get("error")
description = token.get("error_description")
raise AuthenticationError(
f"Unable to get token. Error: {error} (Details: {description})."
)
input_request.headers[
"Authorization"
] = f"{token['token_type']} {token['access_token']}"
return input_request
return token

@abstractmethod
def _get_access_token(self) -> Dict[str, str]:
"""
Retrieves the token dictionary from Azure AD.
Abstract method to return the token dictionary from Azure AD.

Returns
-------
Expand Down
Empty file added test/test_base_client.py
Empty file.
36 changes: 36 additions & 0 deletions test/test_client_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,42 @@
from msal_requests_auth.exceptions import AuthenticationError


@patch("msal.ConfidentialClientApplication", autospec=True)
@patch(
"msal_requests_auth.auth.client_credential.ClientCredentialAuth._get_access_token"
)
def test_client_credential_auth__get_access_token__error(access_token_mock, cca_mock):
access_token_mock.return_value = {
"error": "BAD REQUEST",
"error_description": "Request to get token was bad.",
}
with pytest.raises(
AuthenticationError,
match=(
r"Unable to get token\. Error: BAD REQUEST "
r"\(Details: Request to get token was bad\.\)\."
),
):
ClientCredentialAuth(client=cca_mock, scopes=["TEST SCOPE"]).get_access_token()


@patch("msal.ConfidentialClientApplication", autospec=True)
@patch(
"msal_requests_auth.auth.client_credential.ClientCredentialAuth._get_access_token"
)
def test_client_credential_auth__get_access_token__valid(access_token_mock, cca_mock):
access_token_mock.return_value = {
"token_type": "Bearer",
"access_token": "TEST TOKEN",
}
assert ClientCredentialAuth(
client=cca_mock, scopes=["TEST SCOPE"]
).get_access_token() == {
"token_type": "Bearer",
"access_token": "TEST TOKEN",
}


@patch("msal.ConfidentialClientApplication", autospec=True)
def test_client_credential_auth__no_cache(cca_mock):
cca_mock.acquire_token_silent.return_value = None
Expand Down
36 changes: 35 additions & 1 deletion test/test_devide_code.py → test/test_device_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_device_code_auth__headless(pca_mock, headless):
@patch("msal.PublicClientApplication", autospec=True)
@patch("msal_requests_auth.auth.device_code.webbrowser")
@patch("msal_requests_auth.auth.device_code.pyperclip")
def test_device_code_auth__no_accounts__unable_to_get_token(
def test_device_code_auth__no_accounts__unable_to_get_token__call(
pyperclip_patch, webbrowser_patch, pca_mock
):
pca_mock.get_accounts.return_value = None
Expand Down Expand Up @@ -102,6 +102,40 @@ def test_device_code_auth__no_accounts__unable_to_get_token(
pyperclip_patch.copy.assert_called_with("TEST CODE")


@patch.dict(os.environ, {}, clear=True)
@patch("msal.PublicClientApplication", autospec=True)
@patch("msal_requests_auth.auth.device_code.DeviceCodeAuth._get_access_token")
def test_device_code_auth__get_access_token__error(access_token_mock, pca_mock):
access_token_mock.return_value = {
"error": "BAD REQUEST",
"error_description": "Request to get token was bad.",
}
with pytest.raises(
AuthenticationError,
match=(
r"Unable to get token\. Error: BAD REQUEST "
r"\(Details: Request to get token was bad\.\)\."
),
):
DeviceCodeAuth(client=pca_mock, scopes=["TEST SCOPE"]).get_access_token()


@patch.dict(os.environ, {}, clear=True)
@patch("msal.PublicClientApplication", autospec=True)
@patch("msal_requests_auth.auth.device_code.DeviceCodeAuth._get_access_token")
def test_device_code_auth__get_access_token__valid(access_token_mock, pca_mock):
access_token_mock.return_value = {
"token_type": "Bearer",
"access_token": "TEST TOKEN",
}
assert DeviceCodeAuth(
client=pca_mock, scopes=["TEST SCOPE"]
).get_access_token() == {
"token_type": "Bearer",
"access_token": "TEST TOKEN",
}


@patch.dict(os.environ, {}, clear=True)
@patch("msal.PublicClientApplication", autospec=True)
@patch("msal_requests_auth.auth.device_code.webbrowser")
Expand Down
Loading