diff --git a/tests/conftest.py b/tests/conftest.py index a02a6e6c..e48f35c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from typing import Mapping, Union from unittest.mock import AsyncMock, MagicMock import httpx @@ -6,7 +7,7 @@ from tests.utils.list_resource import list_response_of import workos -from workos.utils.http_client import SyncHTTPClient +from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient class MockResponse(object): @@ -124,20 +125,69 @@ def mock(*args, **kwargs): @pytest.fixture -def mock_sync_http_client_with_response(): - def inner(http_client: SyncHTTPClient, response_dict: dict, status_code: int): - http_client._client.request = MagicMock( - return_value=httpx.Response(status_code, json=response_dict), +def mock_http_client_with_response(monkeypatch): + def inner( + http_client: Union[SyncHTTPClient, AsyncHTTPClient], + response_dict: dict, + status_code: int = 200, + headers: Mapping[str, str] = None, + ): + mock_class = ( + AsyncMock if isinstance(http_client, AsyncHTTPClient) else MagicMock ) + mock = mock_class( + return_value=httpx.Response( + status_code=status_code, headers=headers, json=response_dict + ), + ) + monkeypatch.setattr(http_client._client, "request", mock) return inner @pytest.fixture -def mock_async_http_client_with_response(): - def inner(http_client: SyncHTTPClient, response_dict: dict, status_code: int): - http_client._client.request = AsyncMock( - return_value=httpx.Response(status_code, json=response_dict), +def mock_pagination_request_for_http_client(monkeypatch): + # Mocking pagination correctly requires us to index into a list of data + # and correctly set the before and after metadata in the response. + def inner( + http_client: Union[SyncHTTPClient, AsyncHTTPClient], + data_list: list, + status_code: int = 200, + headers: Mapping[str, str] = None, + ): + # For convenient index lookup, store the list of object IDs. + data_ids = list(map(lambda x: x["id"], data_list)) + + def mock_function(*args, **kwargs): + params = kwargs.get("params") or {} + request_after = params.get("after", None) + limit = params.get("limit", 10) + + if request_after is None: + # First page + start = 0 + else: + # A subsequent page, return the first item _after_ the index we locate + start = data_ids.index(request_after) + 1 + data = data_list[start : start + limit] + if len(data) < limit or len(data) == 0: + # No more data, set after to None + after = None + else: + # Set after to the last item in this page of results + after = data[-1]["id"] + + return httpx.Response( + status_code=status_code, + headers=headers, + json=list_response_of(data=data, before=request_after, after=after), + ) + + mock_class = ( + AsyncMock if isinstance(http_client, AsyncHTTPClient) else MagicMock ) + mock = mock_class(side_effect=mock_function) + + monkeypatch.setattr(http_client._client, "request", mock) return inner diff --git a/tests/test_async_events.py b/tests/test_async_events.py deleted file mode 100644 index 19f99d98..00000000 --- a/tests/test_async_events.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest - -from tests.utils.fixtures.mock_event import MockEvent -from workos.events import AsyncEvents -from workos.utils.http_client import AsyncHTTPClient - - -class TestAsyncEvents(object): - @pytest.fixture(autouse=True) - def setup( - self, - set_api_key, - set_client_id, - ): - self.http_client = AsyncHTTPClient( - base_url="https://api.workos.test", version="test" - ) - self.events = AsyncEvents(http_client=self.http_client) - - @pytest.fixture - def mock_events(self): - events = [MockEvent(id=str(i)).to_dict() for i in range(10)] - - return { - "data": events, - "metadata": { - "params": { - "events": ["dsync.user.created"], - "limit": 10, - "organization_id": None, - "after": None, - "range_start": None, - "range_end": None, - }, - "method": AsyncEvents.list_events, - }, - } - - @pytest.mark.asyncio - async def test_list_events(self, mock_events, mock_async_http_client_with_response): - mock_async_http_client_with_response( - http_client=self.http_client, - status_code=200, - response_dict={"data": mock_events["data"]}, - ) - - events = await self.events.list_events(events=["dsync.user.created"]) - - assert events == mock_events - - @pytest.mark.asyncio - async def test_list_events_returns_metadata( - self, mock_events, mock_async_http_client_with_response - ): - mock_async_http_client_with_response( - http_client=self.http_client, - status_code=200, - response_dict={"data": mock_events["data"]}, - ) - - events = await self.events.list_events( - events=["dsync.user.created"], - ) - - assert events["metadata"]["params"]["events"] == ["dsync.user.created"] - - @pytest.mark.asyncio - async def test_list_events_with_organization_id_returns_metadata( - self, mock_events, mock_async_http_client_with_response - ): - mock_async_http_client_with_response( - http_client=self.http_client, - status_code=200, - response_dict={"data": mock_events["data"]}, - ) - - events = await self.events.list_events( - events=["dsync.user.created"], - organization_id="org_1234", - ) - - assert events["metadata"]["params"]["organization_id"] == "org_1234" - assert events["metadata"]["params"]["events"] == ["dsync.user.created"] diff --git a/tests/test_client.py b/tests/test_client.py index c9b5c64b..d7eabea9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,6 @@ import pytest -from workos import client +from workos import async_client, client from workos.exceptions import ConfigurationException @@ -9,6 +9,7 @@ class TestClient(object): def setup(self): client._audit_logs = None client._directory_sync = None + client._events = None client._organizations = None client._passwordless = None client._portal = None @@ -24,6 +25,9 @@ def test_initialize_audit_logs(self, set_api_key): def test_initialize_directory_sync(self, set_api_key): assert bool(client.directory_sync) + def test_initialize_events(self, set_api_key): + assert bool(client.events) + def test_initialize_organizations(self, set_api_key): assert bool(client.organizations) @@ -76,6 +80,14 @@ def test_initialize_directory_sync_missing_api_key(self): assert "api_key" in message + def test_initialize_events_missing_api_key(self): + with pytest.raises(ConfigurationException) as ex: + client.events + + message = str(ex) + + assert "api_key" in message + def test_initialize_organizations_missing_api_key(self): with pytest.raises(ConfigurationException) as ex: client.organizations @@ -124,3 +136,38 @@ def test_initialize_user_management_missing_api_key_and_client_id(self): assert "api_key" in message assert "client_id" in message + + +class TestAsyncClient(object): + @pytest.fixture(autouse=True) + def setup(self): + async_client._audit_logs = None + async_client._directory_sync = None + async_client._events = None + async_client._organizations = None + async_client._passwordless = None + async_client._portal = None + async_client._sso = None + async_client._user_management = None + + def test_initialize_directory_sync(self, set_api_key): + assert bool(async_client.directory_sync) + + def test_initialize_directory_sync_missing_api_key(self): + with pytest.raises(ConfigurationException) as ex: + async_client.directory_sync + + message = str(ex) + + assert "api_key" in message + + def test_initialize_events(self, set_api_key): + assert bool(async_client.events) + + def test_initialize_events_missing_api_key(self): + with pytest.raises(ConfigurationException) as ex: + async_client.events + + message = str(ex) + + assert "api_key" in message diff --git a/tests/test_directory_sync.py b/tests/test_directory_sync.py index 3afd73e0..5ed9e645 100644 --- a/tests/test_directory_sync.py +++ b/tests/test_directory_sync.py @@ -1,17 +1,14 @@ import pytest from tests.utils.list_resource import list_data_to_dicts, list_response_of -from workos.directory_sync import DirectorySync +from workos.directory_sync import AsyncDirectorySync, DirectorySync +from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient from tests.utils.fixtures.mock_directory import MockDirectory from tests.utils.fixtures.mock_directory_user import MockDirectoryUser from tests.utils.fixtures.mock_directory_group import MockDirectoryGroup -class TestDirectorySync(object): - @pytest.fixture(autouse=True) - def setup(self, set_api_key, set_client_id): - self.directory_sync = DirectorySync() - +class DirectorySyncFixtures: @pytest.fixture def mock_users(self): user_list = [MockDirectoryUser(id=str(i)).to_dict() for i in range(100)] @@ -101,43 +98,72 @@ def mock_directory_groups_multiple_data_pages(self): def mock_directory(self): return MockDirectory("directory_id").to_dict() - def test_list_users_with_directory(self, mock_users, mock_request_method): - mock_request_method("get", mock_users, 200) + +class TestDirectorySync(DirectorySyncFixtures): + @pytest.fixture(autouse=True) + def setup(self, set_api_key, set_client_id): + self.http_client = SyncHTTPClient( + base_url="https://api.workos.test", version="test" + ) + self.directory_sync = DirectorySync(http_client=self.http_client) + + def test_list_users_with_directory( + self, mock_users, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, status_code=200, response_dict=mock_users + ) users = self.directory_sync.list_users(directory="directory_id") assert list_data_to_dicts(users.data) == mock_users["data"] - def test_list_users_with_group(self, mock_users, mock_request_method): - mock_request_method("get", mock_users, 200) + def test_list_users_with_group(self, mock_users, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, status_code=200, response_dict=mock_users + ) users = self.directory_sync.list_users(group="directory_grp_id") assert list_data_to_dicts(users.data) == mock_users["data"] - def test_list_groups_with_directory(self, mock_groups, mock_request_method): - mock_request_method("get", mock_groups, 200) + def test_list_groups_with_directory( + self, mock_groups, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, status_code=200, response_dict=mock_groups + ) groups = self.directory_sync.list_groups(directory="directory_id") assert list_data_to_dicts(groups.data) == mock_groups["data"] - def test_list_groups_with_user(self, mock_groups, mock_request_method): - mock_request_method("get", mock_groups, 200) + def test_list_groups_with_user(self, mock_groups, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, status_code=200, response_dict=mock_groups + ) groups = self.directory_sync.list_groups(user="directory_usr_id") assert list_data_to_dicts(groups.data) == mock_groups["data"] - def test_get_user(self, mock_user, mock_request_method): - mock_request_method("get", mock_user, 200) + def test_get_user(self, mock_user, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_user, + ) user = self.directory_sync.get_user(user="directory_usr_id") assert user.dict() == mock_user - def test_get_group(self, mock_group, mock_request_method): - mock_request_method("get", mock_group, 200) + def test_get_group(self, mock_group, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_group, + ) group = self.directory_sync.get_group( group="directory_group_01FHGRYAQ6ERZXXXXXX1E01QFE" @@ -145,25 +171,33 @@ def test_get_group(self, mock_group, mock_request_method): assert group.dict() == mock_group - def test_list_directories(self, mock_directories, mock_request_method): - mock_request_method("get", mock_directories, 200) + def test_list_directories(self, mock_directories, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_directories, + ) directories = self.directory_sync.list_directories() assert list_data_to_dicts(directories.data) == mock_directories["data"] - def test_get_directory(self, mock_directory, mock_request_method): - mock_request_method("get", mock_directory, 200) + def test_get_directory(self, mock_directory, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_directory, + ) directory = self.directory_sync.get_directory(directory="directory_id") assert directory.dict() == mock_directory - def test_delete_directory(self, mock_directories, mock_raw_request_method): - mock_raw_request_method( - "delete", - "Accepted", - 202, + def test_delete_directory(self, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, + status_code=202, + response_dict=None, headers={"content-type": "text/plain; charset=utf-8"}, ) @@ -172,9 +206,13 @@ def test_delete_directory(self, mock_directories, mock_raw_request_method): assert response is None def test_primary_email( - self, mock_user, mock_user_primary_email, mock_request_method + self, mock_user, mock_user_primary_email, mock_http_client_with_response ): - mock_request_method("get", mock_user, 200) + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_user, + ) mock_user_instance = self.directory_sync.get_user( "directory_user_01E1JG7J09H96KYP8HM9B0G5SJ" ) @@ -182,8 +220,14 @@ def test_primary_email( assert primary_email assert primary_email.dict() == mock_user_primary_email - def test_primary_email_none(self, mock_user_no_email, mock_request_method): - mock_request_method("get", mock_user_no_email, 200) + def test_primary_email_none( + self, mock_user_no_email, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_user_no_email, + ) mock_user_instance = self.directory_sync.get_user( "directory_user_01E1JG7J09H96KYP8HM9B0G5SJ" ) @@ -195,9 +239,13 @@ def test_primary_email_none(self, mock_user_no_email, mock_request_method): def test_list_directories_auto_pagination( self, mock_directories_multiple_data_pages, - mock_pagination_request, + mock_pagination_request_for_http_client, ): - mock_pagination_request("get", mock_directories_multiple_data_pages, 200) + mock_pagination_request_for_http_client( + http_client=self.http_client, + data_list=mock_directories_multiple_data_pages, + status_code=200, + ) directories = self.directory_sync.list_directories() all_directories = [] @@ -213,9 +261,13 @@ def test_list_directories_auto_pagination( def test_directory_users_auto_pagination( self, mock_directory_users_multiple_data_pages, - mock_pagination_request, + mock_pagination_request_for_http_client, ): - mock_pagination_request("get", mock_directory_users_multiple_data_pages, 200) + mock_pagination_request_for_http_client( + http_client=self.http_client, + data_list=mock_directory_users_multiple_data_pages, + status_code=200, + ) users = self.directory_sync.list_users() all_users = [] @@ -231,9 +283,13 @@ def test_directory_users_auto_pagination( def test_directory_user_groups_auto_pagination( self, mock_directory_groups_multiple_data_pages, - mock_pagination_request, + mock_pagination_request_for_http_client, ): - mock_pagination_request("get", mock_directory_groups_multiple_data_pages, 200) + mock_pagination_request_for_http_client( + http_client=self.http_client, + data_list=mock_directory_groups_multiple_data_pages, + status_code=200, + ) groups = self.directory_sync.list_groups() all_groups = [] @@ -249,10 +305,14 @@ def test_directory_user_groups_auto_pagination( def test_auto_pagination_honors_limit( self, mock_directories_multiple_data_pages, - mock_pagination_request, + mock_pagination_request_for_http_client, ): # TODO: This does not actually test anything about the limit. - mock_pagination_request("get", mock_directories_multiple_data_pages, 200) + mock_pagination_request_for_http_client( + http_client=self.http_client, + data_list=mock_directories_multiple_data_pages, + status_code=200, + ) directories = self.directory_sync.list_directories() all_directories = [] @@ -264,3 +324,238 @@ def test_auto_pagination_honors_limit( assert ( list_data_to_dicts(all_directories) ) == mock_directories_multiple_data_pages + + +@pytest.mark.asyncio +class TestAsyncDirectorySync(DirectorySyncFixtures): + @pytest.fixture(autouse=True) + def setup(self, set_api_key, set_client_id): + self.http_client = AsyncHTTPClient( + base_url="https://api.workos.test", + version="test", + ) + self.directory_sync = AsyncDirectorySync(http_client=self.http_client) + + async def test_list_users_with_directory( + self, mock_users, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, status_code=200, response_dict=mock_users + ) + + users = await self.directory_sync.list_users(directory="directory_id") + + assert list_data_to_dicts(users.data) == mock_users["data"] + + async def test_list_users_with_group( + self, mock_users, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, status_code=200, response_dict=mock_users + ) + + users = await self.directory_sync.list_users(group="directory_grp_id") + + assert list_data_to_dicts(users.data) == mock_users["data"] + + async def test_list_groups_with_directory( + self, mock_groups, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, status_code=200, response_dict=mock_groups + ) + + groups = await self.directory_sync.list_groups(directory="directory_id") + + assert list_data_to_dicts(groups.data) == mock_groups["data"] + + async def test_list_groups_with_user( + self, mock_groups, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, status_code=200, response_dict=mock_groups + ) + + groups = await self.directory_sync.list_groups(user="directory_usr_id") + + assert list_data_to_dicts(groups.data) == mock_groups["data"] + + async def test_get_user(self, mock_user, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_user, + ) + + user = await self.directory_sync.get_user(user="directory_usr_id") + + assert user.dict() == mock_user + + async def test_get_group(self, mock_group, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_group, + ) + + group = await self.directory_sync.get_group( + group="directory_group_01FHGRYAQ6ERZXXXXXX1E01QFE" + ) + + assert group.dict() == mock_group + + async def test_list_directories( + self, mock_directories, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_directories, + ) + + directories = await self.directory_sync.list_directories() + + assert list_data_to_dicts(directories.data) == mock_directories["data"] + + async def test_get_directory(self, mock_directory, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_directory, + ) + + directory = await self.directory_sync.get_directory(directory="directory_id") + + assert directory.dict() == mock_directory + + async def test_delete_directory(self, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, + status_code=202, + response_dict=None, + headers={"content-type": "text/plain; charset=utf-8"}, + ) + + response = await self.directory_sync.delete_directory(directory="directory_id") + + assert response is None + + async def test_primary_email( + self, mock_user, mock_user_primary_email, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_user, + ) + mock_user_instance = await self.directory_sync.get_user( + "directory_user_01E1JG7J09H96KYP8HM9B0G5SJ" + ) + primary_email = mock_user_instance.primary_email() + assert primary_email + assert primary_email.dict() == mock_user_primary_email + + async def test_primary_email_none( + self, mock_user_no_email, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict=mock_user_no_email, + ) + mock_user_instance = await self.directory_sync.get_user( + "directory_user_01E1JG7J09H96KYP8HM9B0G5SJ" + ) + + me = mock_user_instance.primary_email() + + assert me == None + + async def test_list_directories_auto_pagination( + self, + mock_directories_multiple_data_pages, + mock_pagination_request_for_http_client, + ): + mock_pagination_request_for_http_client( + http_client=self.http_client, + data_list=mock_directories_multiple_data_pages, + status_code=200, + ) + + directories = await self.directory_sync.list_directories() + all_directories = [] + + async for directory in directories.auto_paging_iter(): + all_directories.append(directory) + + assert len(list(all_directories)) == len(mock_directories_multiple_data_pages) + assert ( + list_data_to_dicts(all_directories) + ) == mock_directories_multiple_data_pages + + async def test_directory_users_auto_pagination( + self, + mock_directory_users_multiple_data_pages, + mock_pagination_request_for_http_client, + ): + mock_pagination_request_for_http_client( + http_client=self.http_client, + data_list=mock_directory_users_multiple_data_pages, + status_code=200, + ) + + users = await self.directory_sync.list_users() + all_users = [] + + async for user in users.auto_paging_iter(): + all_users.append(user) + + assert len(list(all_users)) == len(mock_directory_users_multiple_data_pages) + assert ( + list_data_to_dicts(all_users) + ) == mock_directory_users_multiple_data_pages + + async def test_directory_user_groups_auto_pagination( + self, + mock_directory_groups_multiple_data_pages, + mock_pagination_request_for_http_client, + ): + mock_pagination_request_for_http_client( + http_client=self.http_client, + data_list=mock_directory_groups_multiple_data_pages, + status_code=200, + ) + + groups = await self.directory_sync.list_groups() + all_groups = [] + + async for group in groups.auto_paging_iter(): + all_groups.append(group) + + assert len(list(all_groups)) == len(mock_directory_groups_multiple_data_pages) + assert ( + list_data_to_dicts(all_groups) + ) == mock_directory_groups_multiple_data_pages + + async def test_auto_pagination_honors_limit( + self, + mock_directories_multiple_data_pages, + mock_pagination_request_for_http_client, + ): + # TODO: This does not actually test anything about the limit. + mock_pagination_request_for_http_client( + http_client=self.http_client, + data_list=mock_directories_multiple_data_pages, + status_code=200, + ) + + directories = await self.directory_sync.list_directories() + all_directories = [] + + async for directory in directories.auto_paging_iter(): + all_directories.append(directory) + + assert len(list(all_directories)) == len(mock_directories_multiple_data_pages) + assert ( + list_data_to_dicts(all_directories) + ) == mock_directories_multiple_data_pages diff --git a/tests/test_events.py b/tests/test_events.py index 56ab3b57..ebaca488 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1,8 +1,8 @@ import pytest from tests.utils.fixtures.mock_event import MockEvent -from workos.events import Events -from workos.utils.http_client import SyncHTTPClient +from workos.events import AsyncEvents, Events +from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient class TestEvents(object): @@ -36,8 +36,8 @@ def mock_events(self): }, } - def test_list_events(self, mock_events, mock_sync_http_client_with_response): - mock_sync_http_client_with_response( + def test_list_events(self, mock_events, mock_http_client_with_response): + mock_http_client_with_response( http_client=self.http_client, status_code=200, response_dict={"data": mock_events["data"]}, @@ -48,9 +48,9 @@ def test_list_events(self, mock_events, mock_sync_http_client_with_response): assert events == mock_events def test_list_events_returns_metadata( - self, mock_events, mock_sync_http_client_with_response + self, mock_events, mock_http_client_with_response ): - mock_sync_http_client_with_response( + mock_http_client_with_response( http_client=self.http_client, status_code=200, response_dict={"data": mock_events["data"]}, @@ -63,9 +63,9 @@ def test_list_events_returns_metadata( assert events["metadata"]["params"]["events"] == ["dsync.user.created"] def test_list_events_with_organization_id_returns_metadata( - self, mock_events, mock_sync_http_client_with_response + self, mock_events, mock_http_client_with_response ): - mock_sync_http_client_with_response( + mock_http_client_with_response( http_client=self.http_client, status_code=200, response_dict={"data": mock_events["data"]}, @@ -78,3 +78,79 @@ def test_list_events_with_organization_id_returns_metadata( assert events["metadata"]["params"]["organization_id"] == "org_1234" assert events["metadata"]["params"]["events"] == ["dsync.user.created"] + + +@pytest.mark.asyncio +class TestAsyncEvents(object): + @pytest.fixture(autouse=True) + def setup( + self, + set_api_key, + set_client_id, + ): + self.http_client = AsyncHTTPClient( + base_url="https://api.workos.test", version="test" + ) + self.events = AsyncEvents(http_client=self.http_client) + + @pytest.fixture + def mock_events(self): + events = [MockEvent(id=str(i)).to_dict() for i in range(10)] + + return { + "data": events, + "metadata": { + "params": { + "events": ["dsync.user.created"], + "limit": 10, + "organization_id": None, + "after": None, + "range_start": None, + "range_end": None, + }, + "method": AsyncEvents.list_events, + }, + } + + async def test_list_events(self, mock_events, mock_http_client_with_response): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict={"data": mock_events["data"]}, + ) + + events = await self.events.list_events(events=["dsync.user.created"]) + + assert events == mock_events + + async def test_list_events_returns_metadata( + self, mock_events, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict={"data": mock_events["data"]}, + ) + + events = await self.events.list_events( + events=["dsync.user.created"], + ) + + assert events["metadata"]["params"]["events"] == ["dsync.user.created"] + + async def test_list_events_with_organization_id_returns_metadata( + self, mock_events, mock_http_client_with_response + ): + mock_http_client_with_response( + http_client=self.http_client, + status_code=200, + response_dict={"data": mock_events["data"]}, + ) + + events = await self.events.list_events( + events=["dsync.user.created"], + organization_id="org_1234", + ) + + assert events["metadata"]["params"]["organization_id"] == "org_1234" + assert events["metadata"]["params"]["events"] == ["dsync.user.created"] diff --git a/workos/async_client.py b/workos/async_client.py index 513ad711..c238c387 100644 --- a/workos/async_client.py +++ b/workos/async_client.py @@ -1,4 +1,5 @@ from workos._base_client import BaseClient +from workos.directory_sync import AsyncDirectorySync from workos.events import AsyncEvents from workos.utils.http_client import AsyncHTTPClient @@ -25,9 +26,9 @@ def audit_logs(self): @property def directory_sync(self): - raise NotImplementedError( - "Directory Sync APIs are not yet supported in the async client." - ) + if not getattr(self, "_directory_sync", None): + self._directory_sync = AsyncDirectorySync(self._http_client) + return self._directory_sync @property def events(self): diff --git a/workos/client.py b/workos/client.py index a2ae7748..82803ed7 100644 --- a/workos/client.py +++ b/workos/client.py @@ -37,7 +37,7 @@ def audit_logs(self): @property def directory_sync(self): if not getattr(self, "_directory_sync", None): - self._directory_sync = DirectorySync() + self._directory_sync = DirectorySync(self._http_client) return self._directory_sync @property diff --git a/workos/directory_sync.py b/workos/directory_sync.py index 5fb72ca8..15c497b7 100644 --- a/workos/directory_sync.py +++ b/workos/directory_sync.py @@ -1,11 +1,9 @@ from typing import Optional, Protocol import workos +from workos.typing.sync_or_async import SyncOrAsync +from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient from workos.utils.pagination_order import PaginationOrder -from workos.utils.request import ( - RequestHelper, - REQUEST_METHOD_DELETE, - REQUEST_METHOD_GET, -) +from workos.utils.request import REQUEST_METHOD_DELETE, REQUEST_METHOD_GET from workos.utils.validation import DIRECTORY_SYNC_MODULE, validate_settings from workos.resources.directory_sync import ( @@ -13,7 +11,13 @@ Directory, DirectoryUser, ) -from workos.resources.list import ListArgs, ListPage, WorkOsListResource +from workos.resources.list import ( + ListArgs, + ListPage, + AsyncWorkOsListResource, + SyncOrAsyncListResource, + WorkOsListResource, +) RESPONSE_LIMIT = 10 @@ -47,7 +51,7 @@ def list_users( before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", - ) -> WorkOsListResource[DirectoryUser, DirectoryUserListFilters]: + ) -> SyncOrAsyncListResource: ... def list_groups( @@ -58,16 +62,16 @@ def list_groups( before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", - ) -> WorkOsListResource[DirectoryGroup, DirectoryGroupListFilters]: + ) -> SyncOrAsyncListResource: ... - def get_user(self, user: str) -> DirectoryUser: + def get_user(self, user: str) -> SyncOrAsync[DirectoryUser]: ... - def get_group(self, group: str) -> DirectoryGroup: + def get_group(self, group: str) -> SyncOrAsync[DirectoryGroup]: ... - def get_directory(self, directory: str) -> Directory: + def get_directory(self, directory: str) -> SyncOrAsync[Directory]: ... def list_directories( @@ -79,31 +83,27 @@ def list_directories( after: Optional[str] = None, organization: Optional[str] = None, order: PaginationOrder = "desc", - ) -> WorkOsListResource[Directory, DirectoryListFilters]: + ) -> SyncOrAsyncListResource: ... - def delete_directory(self, directory: str) -> None: + def delete_directory(self, directory: str) -> SyncOrAsync[None]: ... class DirectorySync(DirectorySyncModule): """Offers methods through the WorkOS Directory Sync service.""" - @validate_settings(DIRECTORY_SYNC_MODULE) - def __init__(self): - pass + _http_client: SyncHTTPClient - @property - def request_helper(self): - if not getattr(self, "_request_helper", None): - self._request_helper = RequestHelper() - return self._request_helper + @validate_settings(DIRECTORY_SYNC_MODULE) + def __init__(self, http_client: SyncHTTPClient): + self._http_client = http_client def list_users( self, directory: Optional[str] = None, group: Optional[str] = None, - limit: int = RESPONSE_LIMIT, + limit: Optional[int] = RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", @@ -125,7 +125,7 @@ def list_users( """ list_params: DirectoryUserListFilters = { - "limit": limit, + "limit": limit if limit is not None else RESPONSE_LIMIT, "before": before, "after": after, "order": order, @@ -136,7 +136,7 @@ def list_users( if directory is not None: list_params["directory"] = directory - response = self.request_helper.request( + response = self._http_client.request( "directory_users", method=REQUEST_METHOD_GET, params=list_params, @@ -185,7 +185,7 @@ def list_groups( if directory is not None: list_params["directory"] = directory - response = self.request_helper.request( + response = self._http_client.request( "directory_groups", method=REQUEST_METHOD_GET, params=list_params, @@ -198,7 +198,7 @@ def list_groups( **ListPage[DirectoryGroup](**response).model_dump(), ) - def get_user(self, user: str): + def get_user(self, user: str) -> DirectoryUser: """Gets details for a single provisioned Directory User. Args: @@ -207,7 +207,7 @@ def get_user(self, user: str): Returns: dict: Directory User response from WorkOS. """ - response = self.request_helper.request( + response = self._http_client.request( "directory_users/{user}".format(user=user), method=REQUEST_METHOD_GET, token=workos.api_key, @@ -215,7 +215,7 @@ def get_user(self, user: str): return DirectoryUser.model_validate(response) - def get_group(self, group: str): + def get_group(self, group: str) -> DirectoryGroup: """Gets details for a single provisioned Directory Group. Args: @@ -224,14 +224,14 @@ def get_group(self, group: str): Returns: dict: Directory Group response from WorkOS. """ - response = self.request_helper.request( + response = self._http_client.request( "directory_groups/{group}".format(group=group), method=REQUEST_METHOD_GET, token=workos.api_key, ) return DirectoryGroup.model_validate(response) - def get_directory(self, directory: str): + def get_directory(self, directory: str) -> Directory: """Gets details for a single Directory Args: @@ -242,7 +242,7 @@ def get_directory(self, directory: str): """ - response = self.request_helper.request( + response = self._http_client.request( "directories/{directory}".format(directory=directory), method=REQUEST_METHOD_GET, token=workos.api_key, @@ -285,7 +285,7 @@ def list_directories( "search": search, } - response = self.request_helper.request( + response = self._http_client.request( "directories", method=REQUEST_METHOD_GET, params=list_params, @@ -297,7 +297,229 @@ def list_directories( **ListPage[Directory](**response).model_dump(), ) - def delete_directory(self, directory: str): + def delete_directory(self, directory: str) -> None: + """Delete one existing Directory. + + Args: + directory (str): The ID of the directory to be deleted. (Required) + + Returns: + None + """ + self._http_client.request( + "directories/{directory}".format(directory=directory), + method=REQUEST_METHOD_DELETE, + token=workos.api_key, + ) + + +class AsyncDirectorySync(DirectorySyncModule): + """Offers methods through the WorkOS Directory Sync service.""" + + _http_client: AsyncHTTPClient + + @validate_settings(DIRECTORY_SYNC_MODULE) + def __init__(self, http_client: AsyncHTTPClient): + self._http_client = http_client + + async def list_users( + self, + directory: Optional[str] = None, + group: Optional[str] = None, + limit: int = RESPONSE_LIMIT, + before: Optional[str] = None, + after: Optional[str] = None, + order: PaginationOrder = "desc", + ) -> AsyncWorkOsListResource[DirectoryUser, DirectoryUserListFilters]: + """Gets a list of provisioned Users for a Directory. + + Note, either 'directory' or 'group' must be provided. + + Args: + directory (str): Directory unique identifier. + group (str): Directory Group unique identifier. + limit (int): Maximum number of records to return. + before (str): Pagination cursor to receive records before a provided Directory ID. + after (str): Pagination cursor to receive records after a provided Directory ID. + order (Order): Sort records in either ascending or descending order by created_at timestamp. + + Returns: + dict: Directory Users response from WorkOS. + """ + + list_params = { + "limit": limit, + "before": before, + "after": after, + "order": order, + } + + if group is not None: + list_params["group"] = group + if directory is not None: + list_params["directory"] = directory + + response = await self._http_client.request( + "directory_users", + method=REQUEST_METHOD_GET, + params=list_params, + token=workos.api_key, + ) + + return AsyncWorkOsListResource( + list_method=self.list_users, + list_args=list_params, + **ListPage[DirectoryUser](**response).model_dump(), + ) + + async def list_groups( + self, + directory: Optional[str] = None, + user: Optional[str] = None, + limit: int = RESPONSE_LIMIT, + before: Optional[str] = None, + after: Optional[str] = None, + order: PaginationOrder = "desc", + ) -> AsyncWorkOsListResource[DirectoryGroup, DirectoryGroupListFilters]: + """Gets a list of provisioned Groups for a Directory . + + Note, either 'directory' or 'user' must be provided. + + Args: + directory (str): Directory unique identifier. + user (str): Directory User unique identifier. + limit (int): Maximum number of records to return. + before (str): Pagination cursor to receive records before a provided Directory ID. + after (str): Pagination cursor to receive records after a provided Directory ID. + order (Order): Sort records in either ascending or descending order by created_at timestamp. + + Returns: + dict: Directory Groups response from WorkOS. + """ + list_params = { + "limit": limit, + "before": before, + "after": after, + "order": order, + } + if user is not None: + list_params["user"] = user + if directory is not None: + list_params["directory"] = directory + + response = await self._http_client.request( + "directory_groups", + method=REQUEST_METHOD_GET, + params=list_params, + token=workos.api_key, + ) + + return AsyncWorkOsListResource( + list_method=self.list_groups, + list_args=list_params, + **ListPage[DirectoryGroup](**response).model_dump(), + ) + + async def get_user(self, user: str) -> DirectoryUser: + """Gets details for a single provisioned Directory User. + + Args: + user (str): Directory User unique identifier. + + Returns: + dict: Directory User response from WorkOS. + """ + response = await self._http_client.request( + "directory_users/{user}".format(user=user), + method=REQUEST_METHOD_GET, + token=workos.api_key, + ) + + return DirectoryUser.model_validate(response) + + async def get_group(self, group: str) -> DirectoryGroup: + """Gets details for a single provisioned Directory Group. + + Args: + group (str): Directory Group unique identifier. + + Returns: + dict: Directory Group response from WorkOS. + """ + response = await self._http_client.request( + "directory_groups/{group}".format(group=group), + method=REQUEST_METHOD_GET, + token=workos.api_key, + ) + return DirectoryGroup.model_validate(response) + + async def get_directory(self, directory: str) -> Directory: + """Gets details for a single Directory + + Args: + directory (str): Directory unique identifier. + + Returns: + dict: Directory response from WorkOS + + """ + + response = await self._http_client.request( + "directories/{directory}".format(directory=directory), + method=REQUEST_METHOD_GET, + token=workos.api_key, + ) + + return Directory.model_validate(response) + + async def list_directories( + self, + domain: Optional[str] = None, + search: Optional[str] = None, + limit: int = RESPONSE_LIMIT, + before: Optional[str] = None, + after: Optional[str] = None, + organization: Optional[str] = None, + order: PaginationOrder = "desc", + ) -> AsyncWorkOsListResource[Directory, DirectoryListFilters]: + """Gets details for existing Directories. + + Args: + domain (str): Domain of a Directory. (Optional) + organization: ID of an Organization (Optional) + search (str): Searchable text for a Directory. (Optional) + limit (int): Maximum number of records to return. (Optional) + before (str): Pagination cursor to receive records before a provided Directory ID. (Optional) + after (str): Pagination cursor to receive records after a provided Directory ID. (Optional) + order (Order): Sort records in either ascending or descending order by created_at timestamp. + + Returns: + dict: Directories response from WorkOS. + """ + + list_params = { + "domain": domain, + "organization": organization, + "search": search, + "limit": limit, + "before": before, + "after": after, + "order": order, + } + + response = await self._http_client.request( + "directories", + method=REQUEST_METHOD_GET, + params=list_params, + token=workos.api_key, + ) + return AsyncWorkOsListResource( + list_method=self.list_directories, + list_args=list_params, + **ListPage[Directory](**response).model_dump(), + ) + + async def delete_directory(self, directory: str) -> None: """Delete one existing Directory. Args: @@ -306,7 +528,7 @@ def delete_directory(self, directory: str): Returns: None """ - self.request_helper.request( + await self._http_client.request( "directories/{directory}".format(directory=directory), method=REQUEST_METHOD_DELETE, token=workos.api_key, diff --git a/workos/events.py b/workos/events.py index d612b25d..fc047a86 100644 --- a/workos/events.py +++ b/workos/events.py @@ -1,9 +1,8 @@ -from typing import Awaitable, List, Optional, Protocol, Union +from typing import List, Optional, Protocol import workos -from workos.utils.request import ( - REQUEST_METHOD_GET, -) +from workos.typing.sync_or_async import SyncOrAsync +from workos.utils.request import REQUEST_METHOD_GET from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient from workos.utils.validation import EVENTS_MODULE, validate_settings from workos.resources.list import WorkOSListResource @@ -21,7 +20,7 @@ def list_events( after: Optional[str] = None, range_start: Optional[str] = None, range_end: Optional[str] = None, - ) -> Union[dict, Awaitable[dict]]: + ) -> SyncOrAsync[dict]: ... diff --git a/workos/resources/list.py b/workos/resources/list.py index 99ceda18..145ba4ae 100644 --- a/workos/resources/list.py +++ b/workos/resources/list.py @@ -1,4 +1,7 @@ +import abc from typing import ( + AsyncIterator, + Awaitable, List, Literal, TypeVar, @@ -6,6 +9,7 @@ Callable, Iterator, Optional, + Union, ) from typing_extensions import TypedDict from workos.resources.base import WorkOSBaseResource @@ -137,7 +141,7 @@ class ListArgs(TypedDict): ListAndFilterParams = TypeVar("ListAndFilterParams", bound=ListArgs) -class WorkOsListResource( +class BaseWorkOsListResource( WorkOSModel, Generic[ListableResource, ListAndFilterParams], ): @@ -148,11 +152,7 @@ class WorkOsListResource( list_method: Callable = Field(exclude=True) list_args: ListAndFilterParams = Field(exclude=True) - def auto_paging_iter(self) -> Iterator[ListableResource]: - next_page: WorkOsListResource[ListableResource, ListAndFilterParams] - - after = self.list_metadata.after - + def _parse_params(self): fixed_pagination_params = { "order": self.list_args["order"], "limit": self.list_args["limit"], @@ -163,7 +163,26 @@ def auto_paging_iter(self) -> Iterator[ListableResource]: for k, v in self.list_args.items() if k not in {"order", "limit", "before", "after"} } + + return fixed_pagination_params, filter_params + + @abc.abstractmethod + def auto_paging_iter( + self, + ) -> Union[AsyncIterator[ListableResource], Iterator[ListableResource]]: + ... + + +class WorkOsListResource( + BaseWorkOsListResource, + Generic[ListableResource, ListAndFilterParams], +): + def auto_paging_iter(self) -> Iterator[ListableResource]: + next_page: WorkOsListResource[ListableResource, ListAndFilterParams] + after = self.list_metadata.after + fixed_pagination_params, filter_params = self._parse_params() index: int = 0 + while True: if index >= len(self.data): if after is not None: @@ -178,3 +197,35 @@ def auto_paging_iter(self) -> Iterator[ListableResource]: return yield self.data[index] index += 1 + + +class AsyncWorkOsListResource( + BaseWorkOsListResource, + Generic[ListableResource, ListAndFilterParams], +): + async def auto_paging_iter(self) -> AsyncIterator[ListableResource]: + next_page: WorkOsListResource[ListableResource, ListAndFilterParams] + after = self.list_metadata.after + fixed_pagination_params, filter_params = self._parse_params() + index: int = 0 + + while True: + if index >= len(self.data): + if after is not None: + next_page = await self.list_method( + after=after, **fixed_pagination_params, **filter_params + ) + self.data = next_page.data + after = next_page.list_metadata.after + index = 0 + continue + else: + return + yield self.data[index] + index += 1 + + +SyncOrAsyncListResource = Union[ + Awaitable[AsyncWorkOsListResource], + WorkOsListResource, +] diff --git a/workos/typing/sync_or_async.py b/workos/typing/sync_or_async.py new file mode 100644 index 00000000..d336c76e --- /dev/null +++ b/workos/typing/sync_or_async.py @@ -0,0 +1,5 @@ +from typing import Awaitable, TypeVar, Union + + +T = TypeVar("T") +SyncOrAsync = Union[T, Awaitable[T]] diff --git a/workos/utils/_base_http_client.py b/workos/utils/_base_http_client.py index aab0f7d6..0baeca9d 100644 --- a/workos/utils/_base_http_client.py +++ b/workos/utils/_base_http_client.py @@ -3,6 +3,7 @@ cast, Dict, Generic, + Mapping, Optional, TypeVar, TypedDict, @@ -32,8 +33,8 @@ class PreparedRequest(TypedDict): method: str url: str headers: httpx.Headers - params: NotRequired[Union[Dict, None]] - json: NotRequired[Union[Dict, None]] + params: NotRequired[Union[Mapping, None]] + json: NotRequired[Union[Mapping, None]] timeout: int @@ -104,7 +105,7 @@ def _prepare_request( self, path: str, method: Optional[str] = REQUEST_METHOD_GET, - params: Optional[dict] = None, + params: Optional[Mapping] = None, headers: Optional[dict] = None, token: Optional[str] = None, ) -> PreparedRequest: diff --git a/workos/utils/http_client.py b/workos/utils/http_client.py index d89f324f..838dd49b 100644 --- a/workos/utils/http_client.py +++ b/workos/utils/http_client.py @@ -1,8 +1,9 @@ import asyncio -from typing import Awaitable, Optional +from typing import Dict, Mapping, Optional, Union, TypedDict import httpx +from workos.resources.list import ListArgs from workos.utils._base_http_client import BaseHTTPClient from workos.utils.request import REQUEST_METHOD_GET @@ -68,7 +69,7 @@ def request( self, path: str, method: Optional[str] = REQUEST_METHOD_GET, - params: Optional[dict] = None, + params: Optional[Mapping] = None, headers: Optional[dict] = None, token: Optional[str] = None, ) -> dict: