diff --git a/test/test_oauth_handler.py b/test/test_oauth_handler.py index 50e031e..f8cf977 100644 --- a/test/test_oauth_handler.py +++ b/test/test_oauth_handler.py @@ -12,6 +12,7 @@ from iolite_client.oauth_handler import ( AsyncOAuthHandler, + AsyncOAuthWrapper, OAuthHandler, OAuthHandlerHelper, OAuthWrapper, @@ -124,8 +125,7 @@ def test_invalid_token_refresh(self): self.mock_oauth_handler.get_new_access_token.return_value = response - oauth_wrapper = OAuthWrapper(self.mock_oauth_handler, self.mock_oauth_storage) - oauth_wrapper.get_sid(token) + self.oauth_wrapper.get_sid(token) self.assertEqual(self.mock_oauth_handler.get_sid.call_count, 2) @@ -136,3 +136,66 @@ def _get_token(date_time: datetime.datetime) -> dict: "access_token": "access-token", "refresh_token": "refresh-token", } + + +class AsyncOAuthWrapperTest(unittest.TestCase): + def setUp(self) -> None: + self.mock_async_oauth_handler = Mock() + self.mock_async_oauth_storage = Mock() + self.async_oauth_wrapper = AsyncOAuthWrapper( + self.mock_async_oauth_handler, self.mock_async_oauth_storage + ) + + @pytest.mark.enable_socket + @pytest.mark.asyncio + @freeze_time("2021-01-01 00:00:00") + async def test_get_sid_valid_access_token(self): + token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 1)) + await self.async_oauth_wrapper.get_sid(token) + self.mock_async_oauth_handler.get_sid.assert_called_once_with("access-token") + + @pytest.mark.enable_socket + @pytest.mark.asyncio + @freeze_time("2021-01-01 00:00:01") + async def test_get_sid_expired_access_token(self): + token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 0)) + response = self._get_token(datetime.datetime(2021, 1, 10, 0, 0, 0)) + self.mock_async_oauth_handler.get_new_access_token.return_value = response + + await self.async_oauth_wrapper.get_sid(token) + self.mock_async_oauth_handler.get_new_access_token.assert_called_once_with( + "refresh-token" + ) + self.mock_async_oauth_storage.store_access_token.assert_called_once_with( + response + ) + self.mock_async_oauth_handler.get_sid.assert_called_once_with("access-token") + + @pytest.mark.enable_socket + @pytest.mark.asyncio + @freeze_time("2021-01-01 00:00:00") + async def test_invalid_token_refresh(self): + + token = self._get_token(datetime.datetime(2021, 1, 1, 0, 0, 1)) + self.mock_async_oauth_storage.fetch_access_token.return_value = token + + self.mock_async_oauth_handler.get_sid.side_effect = [ + HTTPError("Something went wrong"), + "sid", + ] + + response = self._get_token(datetime.datetime(2021, 1, 10, 0, 0, 0)) + + self.mock_async_oauth_handler.get_new_access_token.return_value = response + + await self.async_oauth_wrapper.get_sid(token) + + self.assertEqual(self.mock_async_oauth_handler.get_sid.call_count, 2) + + @staticmethod + def _get_token(date_time: datetime.datetime) -> dict: + return { + "expires_at": date_time.timestamp(), + "access_token": "access-token", + "refresh_token": "refresh-token", + }