diff --git a/msal_requests_auth/auth/base_auth_client.py b/msal_requests_auth/auth/base_auth_client.py index b1e13f9..ffc24ef 100644 --- a/msal_requests_auth/auth/base_auth_client.py +++ b/msal_requests_auth/auth/base_auth_client.py @@ -45,6 +45,20 @@ def __call__( """ Adds the token to the authorization header. """ + token = self.get_access_token() + input_request.headers[ + "Authorization" + ] = f"{token['token_type']} {token['access_token']}" + return input_request + + def get_access_token(self) -> Dict[str, str]: + """ + Retrieves the token dictionary from Azure AD. + + Returns + ------- + dict + """ token = self._get_access_token() if "access_token" not in token: error = token.get("error") @@ -52,15 +66,12 @@ def __call__( raise AuthenticationError( f"Unable to get token. Error: {error} (Details: {description})." ) - input_request.headers[ - "Authorization" - ] = f"{token['token_type']} {token['access_token']}" - return input_request + return token @abstractmethod def _get_access_token(self) -> Dict[str, str]: """ - Retrieves the token dictionary from Azure AD. + Abstract method to return the token dictionary from Azure AD. Returns ------- diff --git a/test/test_base_client.py b/test/test_base_client.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_client_credential.py b/test/test_client_credential.py index abfb349..abd36a2 100644 --- a/test/test_client_credential.py +++ b/test/test_client_credential.py @@ -6,6 +6,42 @@ from msal_requests_auth.exceptions import AuthenticationError +@patch("msal.ConfidentialClientApplication", autospec=True) +@patch( + "msal_requests_auth.auth.client_credential.ClientCredentialAuth._get_access_token" +) +def test_client_credential_auth__get_access_token__error(access_token_mock, cca_mock): + access_token_mock.return_value = { + "error": "BAD REQUEST", + "error_description": "Request to get token was bad.", + } + with pytest.raises( + AuthenticationError, + match=( + r"Unable to get token\. Error: BAD REQUEST " + r"\(Details: Request to get token was bad\.\)\." + ), + ): + ClientCredentialAuth(client=cca_mock, scopes=["TEST SCOPE"]).get_access_token() + + +@patch("msal.ConfidentialClientApplication", autospec=True) +@patch( + "msal_requests_auth.auth.client_credential.ClientCredentialAuth._get_access_token" +) +def test_client_credential_auth__get_access_token__valid(access_token_mock, cca_mock): + access_token_mock.return_value = { + "token_type": "Bearer", + "access_token": "TEST TOKEN", + } + assert ClientCredentialAuth( + client=cca_mock, scopes=["TEST SCOPE"] + ).get_access_token() == { + "token_type": "Bearer", + "access_token": "TEST TOKEN", + } + + @patch("msal.ConfidentialClientApplication", autospec=True) def test_client_credential_auth__no_cache(cca_mock): cca_mock.acquire_token_silent.return_value = None diff --git a/test/test_devide_code.py b/test/test_device_code.py similarity index 85% rename from test/test_devide_code.py rename to test/test_device_code.py index b81f7d1..60768b3 100644 --- a/test/test_devide_code.py +++ b/test/test_device_code.py @@ -66,7 +66,7 @@ def test_device_code_auth__headless(pca_mock, headless): @patch("msal.PublicClientApplication", autospec=True) @patch("msal_requests_auth.auth.device_code.webbrowser") @patch("msal_requests_auth.auth.device_code.pyperclip") -def test_device_code_auth__no_accounts__unable_to_get_token( +def test_device_code_auth__no_accounts__unable_to_get_token__call( pyperclip_patch, webbrowser_patch, pca_mock ): pca_mock.get_accounts.return_value = None @@ -102,6 +102,40 @@ def test_device_code_auth__no_accounts__unable_to_get_token( pyperclip_patch.copy.assert_called_with("TEST CODE") +@patch.dict(os.environ, {}, clear=True) +@patch("msal.PublicClientApplication", autospec=True) +@patch("msal_requests_auth.auth.device_code.DeviceCodeAuth._get_access_token") +def test_device_code_auth__get_access_token__error(access_token_mock, pca_mock): + access_token_mock.return_value = { + "error": "BAD REQUEST", + "error_description": "Request to get token was bad.", + } + with pytest.raises( + AuthenticationError, + match=( + r"Unable to get token\. Error: BAD REQUEST " + r"\(Details: Request to get token was bad\.\)\." + ), + ): + DeviceCodeAuth(client=pca_mock, scopes=["TEST SCOPE"]).get_access_token() + + +@patch.dict(os.environ, {}, clear=True) +@patch("msal.PublicClientApplication", autospec=True) +@patch("msal_requests_auth.auth.device_code.DeviceCodeAuth._get_access_token") +def test_device_code_auth__get_access_token__valid(access_token_mock, pca_mock): + access_token_mock.return_value = { + "token_type": "Bearer", + "access_token": "TEST TOKEN", + } + assert DeviceCodeAuth( + client=pca_mock, scopes=["TEST SCOPE"] + ).get_access_token() == { + "token_type": "Bearer", + "access_token": "TEST TOKEN", + } + + @patch.dict(os.environ, {}, clear=True) @patch("msal.PublicClientApplication", autospec=True) @patch("msal_requests_auth.auth.device_code.webbrowser")