From ffcbd556acebe04ef629d6aee35d9c20c20e9c3f Mon Sep 17 00:00:00 2001 From: mattgd Date: Tue, 13 Aug 2024 18:14:26 -0400 Subject: [PATCH 1/3] Add async support for organizations. --- tests/test_organizations.py | 155 ++++++++++++++++++++++++++++++++-- tests/test_user_management.py | 1 - workos/async_client.py | 10 +-- workos/organizations.py | 122 ++++++++++++++++++++++++-- 4 files changed, 267 insertions(+), 21 deletions(-) diff --git a/tests/test_organizations.py b/tests/test_organizations.py index 71d4f565..a3087bb2 100644 --- a/tests/test_organizations.py +++ b/tests/test_organizations.py @@ -1,16 +1,11 @@ import datetime import pytest from tests.utils.list_resource import list_data_to_dicts, list_response_of -from workos.organizations import Organizations +from workos.organizations import AsyncOrganizations, Organizations from tests.utils.fixtures.mock_organization import MockOrganization -class TestOrganizations(object): - @pytest.fixture(autouse=True) - def setup(self, sync_http_client_for_test): - self.http_client = sync_http_client_for_test - self.organizations = Organizations(http_client=self.http_client) - +class OrganizationFixtures: @pytest.fixture def mock_organization(self): return MockOrganization("org_01EHT88Z8J8795GZNQ4ZP1J81T").dict() @@ -63,6 +58,13 @@ def mock_organizations_multiple_data_pages(self): ] return list_response_of(data=organizations_list) + +class TestOrganizations(OrganizationFixtures): + @pytest.fixture(autouse=True) + def setup(self, sync_http_client_for_test): + self.http_client = sync_http_client_for_test + self.organizations = Organizations(http_client=self.http_client) + def test_list_organizations( self, mock_organizations, mock_http_client_with_response ): @@ -200,3 +202,142 @@ def test_list_organizations_auto_pagination_for_multiple_pages( list_function=self.organizations.list_organizations, expected_all_page_data=mock_organizations_multiple_data_pages["data"], ) + + +@pytest.mark.asyncio +class TestAsyncOrganizations(OrganizationFixtures): + @pytest.fixture(autouse=True) + def setup(self, async_http_client_for_test): + self.http_client = async_http_client_for_test + self.organizations = AsyncOrganizations(http_client=self.http_client) + + async def test_list_organizations( + self, mock_organizations, mock_http_client_with_response + ): + mock_http_client_with_response(self.http_client, mock_organizations, 200) + + organizations_response = await self.organizations.list_organizations() + + def to_dict(x): + return x.dict() + + assert ( + list(map(to_dict, organizations_response.data)) + == mock_organizations["data"] + ) + + async def test_get_organization( + self, mock_organization, mock_http_client_with_response + ): + mock_http_client_with_response(self.http_client, mock_organization, 200) + + organization = await self.organizations.get_organization( + organization_id="organization_id" + ) + + assert organization.dict() == mock_organization + + async def test_get_organization_by_lookup_key( + self, mock_organization, mock_http_client_with_response + ): + mock_http_client_with_response(self.http_client, mock_organization, 200) + + organization = await self.organizations.get_organization_by_lookup_key( + lookup_key="test" + ) + + assert organization.dict() == mock_organization + + async def test_create_organization_with_domain_data( + self, mock_organization, mock_http_client_with_response + ): + mock_http_client_with_response(self.http_client, mock_organization, 201) + + payload = { + "domain_data": [{"domain": "example.com", "state": "verified"}], + "name": "Test Organization", + } + organization = await self.organizations.create_organization(**payload) + + assert organization.id == "org_01EHT88Z8J8795GZNQ4ZP1J81T" + assert organization.name == "Foo Corporation" + + async def test_sends_idempotency_key( + self, mock_organization, capture_and_mock_http_client_request + ): + idempotency_key = "test_123456789" + + payload = { + "domain_data": [{"domain": "example.com", "state": "verified"}], + "name": "Foo Corporation", + } + + request_kwargs = capture_and_mock_http_client_request( + self.http_client, mock_organization, 200 + ) + + response = await self.organizations.create_organization( + **payload, idempotency_key=idempotency_key + ) + + assert request_kwargs["headers"]["idempotency-key"] == idempotency_key + assert response.name == "Foo Corporation" + + async def test_update_organization_with_domain_data( + self, mock_organization_updated, mock_http_client_with_response + ): + mock_http_client_with_response(self.http_client, mock_organization_updated, 201) + + updated_organization = await self.organizations.update_organization( + organization_id="org_01EHT88Z8J8795GZNQ4ZP1J81T", + name="Example Organization", + domain_data=[{"domain": "example.io", "state": "verified"}], + ) + + assert updated_organization.id == "org_01EHT88Z8J8795GZNQ4ZP1J81T" + assert updated_organization.name == "Example Organization" + assert updated_organization.domains[0].dict() == { + "domain": "example.io", + "object": "organization_domain", + "id": "org_domain_01EHT88Z8WZEFWYPM6EC9BX2R8", + "state": "verified", + "organization_id": "org_01EHT88Z8J8795GZNQ4ZP1J81T", + "verification_strategy": "dns", + "verification_token": "token", + } + + async def test_delete_organization(self, setup, mock_http_client_with_response): + mock_http_client_with_response( + self.http_client, + 202, + headers={"content-type": "text/plain; charset=utf-8"}, + ) + + response = await self.organizations.delete_organization( + organization_id="connection_id" + ) + + assert response is None + + async def test_list_organizations_auto_pagination_for_multiple_pages( + self, + mock_organizations_multiple_data_pages, + mock_pagination_request_for_http_client, + ): + mock_pagination_request_for_http_client( + http_client=self.http_client, + data_list=mock_organizations_multiple_data_pages["data"], + status_code=200, + ) + + all_organizations = [] + + async for organization in await self.organizations.list_organizations(): + all_organizations.append(organization) + + assert len(list(all_organizations)) == len( + mock_organizations_multiple_data_pages["data"] + ) + assert ( + list_data_to_dicts(all_organizations) + ) == mock_organizations_multiple_data_pages["data"] diff --git a/tests/test_user_management.py b/tests/test_user_management.py index 43cf81f0..ca79adc8 100644 --- a/tests/test_user_management.py +++ b/tests/test_user_management.py @@ -1,5 +1,4 @@ import json -from os import sync from six.moves.urllib.parse import parse_qsl, urlparse import pytest diff --git a/workos/async_client.py b/workos/async_client.py index 5908b348..c65af0fd 100644 --- a/workos/async_client.py +++ b/workos/async_client.py @@ -6,7 +6,7 @@ from workos.events import AsyncEvents from workos.fga import FGAModule from workos.mfa import MFAModule -from workos.organizations import OrganizationsModule +from workos.organizations import AsyncOrganizations from workos.passwordless import PasswordlessModule from workos.portal import PortalModule from workos.sso import AsyncSSO @@ -73,10 +73,10 @@ def fga(self) -> FGAModule: raise NotImplementedError("FGA APIs are not yet supported in the async client.") @property - def organizations(self) -> OrganizationsModule: - raise NotImplementedError( - "Organizations APIs are not yet supported in the async client." - ) + def organizations(self) -> AsyncOrganizations: + if not getattr(self, "_organizations", None): + self._organizations = AsyncOrganizations(self._http_client) + return self._organizations @property def passwordless(self) -> PasswordlessModule: diff --git a/workos/organizations.py b/workos/organizations.py index 27744fb7..eebcf891 100644 --- a/workos/organizations.py +++ b/workos/organizations.py @@ -2,7 +2,8 @@ from workos.types.organizations.domain_data_input import DomainDataInput from workos.types.organizations.list_filters import OrganizationListFilters -from workos.utils.http_client import SyncHTTPClient +from workos.typing.sync_or_async import SyncOrAsync +from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient from workos.utils.pagination_order import PaginationOrder from workos.utils.request_helper import ( DEFAULT_LIST_RESPONSE_LIMIT, @@ -33,7 +34,7 @@ def list_organizations( before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", - ) -> OrganizationsListResource: + ) -> SyncOrAsync[OrganizationsListResource]: """Retrieve a list of organizations that have connections configured within your WorkOS dashboard. Kwargs: @@ -48,7 +49,7 @@ def list_organizations( """ ... - def get_organization(self, organization_id: str) -> Organization: + def get_organization(self, organization_id: str) -> SyncOrAsync[Organization]: """Gets details for a single Organization Args: @@ -58,7 +59,9 @@ def get_organization(self, organization_id: str) -> Organization: """ ... - def get_organization_by_lookup_key(self, lookup_key: str) -> Organization: + def get_organization_by_lookup_key( + self, lookup_key: str + ) -> SyncOrAsync[Organization]: """Gets details for a single Organization by lookup key Args: @@ -75,7 +78,9 @@ def create_organization( name: str, domain_data: Optional[Sequence[DomainDataInput]] = None, idempotency_key: Optional[str] = None, - ) -> Organization: ... + ) -> SyncOrAsync[Organization]: + """Create an organization""" + ... def update_organization( self, @@ -83,7 +88,7 @@ def update_organization( organization_id: str, name: Optional[str] = None, domain_data: Optional[Sequence[DomainDataInput]] = None, - ) -> Organization: + ) -> SyncOrAsync[Organization]: """Update an organization Kwargs: @@ -97,7 +102,7 @@ def update_organization( """ ... - def delete_organization(self, organization_id: str) -> None: + def delete_organization(self, organization_id: str) -> SyncOrAsync[None]: """Deletes a single Organization Args: @@ -167,7 +172,6 @@ def create_organization( domain_data: Optional[Sequence[DomainDataInput]] = None, idempotency_key: Optional[str] = None, ) -> Organization: - """Create an organization""" headers = {} if idempotency_key: headers["idempotency-key"] = idempotency_key @@ -210,3 +214,105 @@ def delete_organization(self, organization_id: str) -> None: f"organizations/{organization_id}", method=REQUEST_METHOD_DELETE, ) + + +class AsyncOrganizations(OrganizationsModule): + + _http_client: AsyncHTTPClient + + def __init__(self, http_client: AsyncHTTPClient): + self._http_client = http_client + + async def list_organizations( + self, + *, + domains: Optional[Sequence[str]] = None, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, + before: Optional[str] = None, + after: Optional[str] = None, + order: PaginationOrder = "desc", + ) -> OrganizationsListResource: + list_params: OrganizationListFilters = { + "limit": limit, + "before": before, + "after": after, + "order": order, + "domains": domains, + } + + response = await self._http_client.request( + ORGANIZATIONS_PATH, + method=REQUEST_METHOD_GET, + params=list_params, + ) + + return WorkOSListResource[Organization, OrganizationListFilters, ListMetadata]( + list_method=self.list_organizations, + list_args=list_params, + **ListPage[Organization](**response).model_dump(), + ) + + async def get_organization(self, organization_id: str) -> Organization: + response = await self._http_client.request( + f"organizations/{organization_id}", method=REQUEST_METHOD_GET + ) + + return Organization.model_validate(response) + + async def get_organization_by_lookup_key(self, lookup_key: str) -> Organization: + response = await self._http_client.request( + "organizations/by_lookup_key/{lookup_key}".format(lookup_key=lookup_key), + method=REQUEST_METHOD_GET, + ) + + return Organization.model_validate(response) + + async def create_organization( + self, + *, + name: str, + domain_data: Optional[Sequence[DomainDataInput]] = None, + idempotency_key: Optional[str] = None, + ) -> Organization: + headers = {} + if idempotency_key: + headers["idempotency-key"] = idempotency_key + + json = { + "name": name, + "domain_data": domain_data, + "idempotency_key": idempotency_key, + } + + response = await self._http_client.request( + ORGANIZATIONS_PATH, + method=REQUEST_METHOD_POST, + json=json, + headers=headers, + ) + + return Organization.model_validate(response) + + async def update_organization( + self, + *, + organization_id: str, + name: Optional[str] = None, + domain_data: Optional[Sequence[DomainDataInput]] = None, + ) -> Organization: + json = { + "name": name, + "domain_data": domain_data, + } + + response = await self._http_client.request( + f"organizations/{organization_id}", method=REQUEST_METHOD_PUT, json=json + ) + + return Organization.model_validate(response) + + async def delete_organization(self, organization_id: str) -> None: + await self._http_client.request( + f"organizations/{organization_id}", + method=REQUEST_METHOD_DELETE, + ) From 523169400edceaa3531642727acb5a44f038301a Mon Sep 17 00:00:00 2001 From: mattgd Date: Wed, 14 Aug 2024 10:58:39 -0400 Subject: [PATCH 2/3] Override limit when using auto paginator. --- workos/types/list_resource.py | 10 +++++++--- workos/types/webhooks/webhook.py | 4 +--- workos/user_management.py | 2 +- workos/utils/_base_http_client.py | 2 +- workos/webhooks.py | 1 - 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/workos/types/list_resource.py b/workos/types/list_resource.py index 6b289b82..d217795e 100644 --- a/workos/types/list_resource.py +++ b/workos/types/list_resource.py @@ -29,6 +29,7 @@ from workos.types.sso import ConnectionWithDomains from workos.types.user_management import Invitation, OrganizationMembership, User from workos.types.workos_model import WorkOSModel +from workos.utils.request_helper import DEFAULT_LIST_RESPONSE_LIMIT ListableResource = TypeVar( # add all possible generics of List Resource @@ -99,13 +100,13 @@ class WorkOSListResource( list_args: ListAndFilterParams = Field(exclude=True) def _parse_params( - self, + self, limit_override: Optional[int] = None ) -> Tuple[Dict[str, Union[int, str, None]], Mapping[str, Any]]: fixed_pagination_params = cast( # Type hints consider this a mismatch because it assume the dictionary is dict[str, int] Dict[str, Union[int, str, None]], { - "limit": self.list_args["limit"], + "limit": limit_override or self.list_args["limit"], }, ) if "order" in self.list_args: @@ -133,7 +134,10 @@ def __iter__(self) -> Iterator[ListableResource]: # type: ignore ListableResource, ListAndFilterParams, ListMetadataType ] after = self.list_metadata.after - fixed_pagination_params, filter_params = self._parse_params() + fixed_pagination_params, filter_params = self._parse_params( + # Singe we're auto-paginating, ignore the original limit and use the default + limit_override=DEFAULT_LIST_RESPONSE_LIMIT + ) index: int = 0 while True: diff --git a/workos/types/webhooks/webhook.py b/workos/types/webhooks/webhook.py index 7facb05d..b5710be7 100644 --- a/workos/types/webhooks/webhook.py +++ b/workos/types/webhooks/webhook.py @@ -1,11 +1,9 @@ -from typing import Generic, Literal, Union +from typing import Literal, Union from pydantic import Field from typing_extensions import Annotated from workos.types.directory_sync import DirectoryGroup -from workos.types.events import EventPayload from workos.types.user_management import OrganizationMembership, User from workos.types.webhooks.webhook_model import WebhookModel -from workos.types.workos_model import WorkOSModel from workos.types.directory_sync.directory_user import DirectoryUser from workos.types.events.authentication_payload import ( AuthenticationEmailVerificationSucceededPayload, diff --git a/workos/user_management.py b/workos/user_management.py index 1f781daf..6d389649 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -47,7 +47,7 @@ UserManagementProviderType, ) from workos.typing.sync_or_async import SyncOrAsync -from workos.utils.http_client import AsyncHTTPClient, HTTPClient, SyncHTTPClient +from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient from workos.utils.pagination_order import PaginationOrder from workos.utils.request_helper import ( DEFAULT_LIST_RESPONSE_LIMIT, diff --git a/workos/utils/_base_http_client.py b/workos/utils/_base_http_client.py index ff9c127f..bfd8eea1 100644 --- a/workos/utils/_base_http_client.py +++ b/workos/utils/_base_http_client.py @@ -214,7 +214,7 @@ def auth_headers(self) -> Mapping[str, str]: def auth_header_from_token(self, token: str) -> Mapping[str, str]: return { - "Authorization": f"Bearer {token }", + "Authorization": f"Bearer {token}", } @property diff --git a/workos/webhooks.py b/workos/webhooks.py index 473a1ed9..948210db 100644 --- a/workos/webhooks.py +++ b/workos/webhooks.py @@ -1,7 +1,6 @@ import hashlib import hmac import time -import hashlib from typing import Optional, Protocol from workos.types.webhooks.webhook import Webhook from workos.types.webhooks.webhook_payload import WebhookPayload From 15e306a65d5d8bc7f196c75b452c143f3b911475 Mon Sep 17 00:00:00 2001 From: mattgd Date: Wed, 14 Aug 2024 12:33:48 -0400 Subject: [PATCH 3/3] Update doc. --- workos/organizations.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/workos/organizations.py b/workos/organizations.py index eebcf891..61ec9218 100644 --- a/workos/organizations.py +++ b/workos/organizations.py @@ -79,7 +79,16 @@ def create_organization( domain_data: Optional[Sequence[DomainDataInput]] = None, idempotency_key: Optional[str] = None, ) -> SyncOrAsync[Organization]: - """Create an organization""" + """Create an organization + + Kwargs: + name (str): A descriptive name for the organization. (Optional) + domain_data (Sequence[DomainDataInput]): List of domains that belong to the organization. (Optional) + idempotency_key (str): Key to guarantee idempotency across requests. (Optional) + + Returns: + Organization: Updated Organization response from WorkOS. + """ ... def update_organization( @@ -94,7 +103,6 @@ def update_organization( Kwargs: organization (str): Organization's unique identifier. name (str): A descriptive name for the organization. (Optional) - domains (list): [Deprecated] Use domain_data instead. List of domains that belong to the organization. (Optional) domain_data (Sequence[DomainDataInput]): List of domains that belong to the organization. (Optional) Returns: