From dc703521faa89cc1ee0a6444325895e945ca6ec2 Mon Sep 17 00:00:00 2001 From: Peter Arzhintar Date: Mon, 12 Aug 2024 15:25:13 -0700 Subject: [PATCH 01/11] Remove token and add ignore comment to shadowed import --- workos/utils/http_client.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/workos/utils/http_client.py b/workos/utils/http_client.py index 6da22677..e5fe42a7 100644 --- a/workos/utils/http_client.py +++ b/workos/utils/http_client.py @@ -1,7 +1,7 @@ import asyncio from types import TracebackType from typing import Optional, Type, Union -from typing_extensions import Self +from typing_extensions import Self # type: ignore shadowed import, Self was added to typing in 3.11 import httpx @@ -83,7 +83,6 @@ def request( params: ParamsType = None, json: JsonType = None, headers: HeadersType = None, - token: Optional[str] = None, ) -> ResponseJson: """Executes a request against the WorkOS API. @@ -94,7 +93,6 @@ def request( method (str): One of the supported methods as defined by the REQUEST_METHOD_X constants params (ParamsType): Query params to be added to the request json (JsonType): Body payload to be added to the request - token (str): Bearer token Returns: ResponseJson: Response from WorkOS @@ -185,7 +183,6 @@ async def request( method (str): One of the supported methods as defined by the REQUEST_METHOD_X constants params (ParamsType): Query params to be added to the request json (JsonType): Body payload to be added to the request - token (str): Bearer token Returns: ResponseJson: Response from WorkOS From 4bdd2bb159c6bb120d0fe9a7e9c0fd23184205a1 Mon Sep 17 00:00:00 2001 From: Peter Arzhintar Date: Mon, 12 Aug 2024 15:28:10 -0700 Subject: [PATCH 02/11] More token fixes --- workos/sso.py | 4 +--- workos/utils/http_client.py | 4 +++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/workos/sso.py b/workos/sso.py index 50f70e26..fb8c36ae 100644 --- a/workos/sso.py +++ b/workos/sso.py @@ -138,9 +138,7 @@ def get_profile(self, access_token: str) -> Profile: Returns: Profile """ - response = self._http_client.request( - PROFILE_PATH, method=REQUEST_METHOD_GET, token=access_token - ) + response = self._http_client.request(PROFILE_PATH, method=REQUEST_METHOD_GET) return Profile.model_validate(response) diff --git a/workos/utils/http_client.py b/workos/utils/http_client.py index e5fe42a7..9a23bcc0 100644 --- a/workos/utils/http_client.py +++ b/workos/utils/http_client.py @@ -1,7 +1,9 @@ import asyncio from types import TracebackType from typing import Optional, Type, Union -from typing_extensions import Self # type: ignore shadowed import, Self was added to typing in 3.11 + +# Self was added to typing in Python 3.11 +from typing_extensions import Self import httpx From 425809f7f083678d630eead6b88d19151b66f35a Mon Sep 17 00:00:00 2001 From: Peter Arzhintar Date: Mon, 12 Aug 2024 16:00:53 -0700 Subject: [PATCH 03/11] Minor typing fixes and extract protocol for client config --- workos/_base_client.py | 19 +++++++++++++++++-- workos/sso.py | 22 ++++++++++++++++------ workos/utils/http_client.py | 8 +++++++- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/workos/_base_client.py b/workos/_base_client.py index b68f9b93..0bb20f3b 100644 --- a/workos/_base_client.py +++ b/workos/_base_client.py @@ -1,6 +1,6 @@ from abc import abstractmethod import os -from typing import Generic, Optional, Type, TypeVar +from typing import Generic, Optional, Protocol, Type, TypeVar from workos.__about__ import __version__ from workos.fga import FGAModule @@ -21,7 +21,14 @@ HTTPClientType = TypeVar("HTTPClientType", bound=HTTPClient) -class BaseClient(Generic[HTTPClientType]): +class ClientConfiguration(Protocol): + @property + def base_url(self) -> str: ... + @property + def client_id(self) -> str: ... + + +class BaseClient(Generic[HTTPClientType], ClientConfiguration): """Base client for accessing the WorkOS feature set.""" _api_key: str @@ -117,3 +124,11 @@ def user_management(self) -> UserManagementModule: ... @property @abstractmethod def webhooks(self) -> WebhooksModule: ... + + @property + def base_url(self) -> str: + return self._base_url + + @property + def client_id(self) -> str: + return self._client_id diff --git a/workos/sso.py b/workos/sso.py index fb8c36ae..c4ba7f78 100644 --- a/workos/sso.py +++ b/workos/sso.py @@ -1,4 +1,5 @@ from typing import Optional, Protocol +from workos._base_client import ClientConfiguration from workos.types.sso.connection import ConnectionType from workos.types.sso.sso_provider_type import SsoProviderType from workos.typing.sync_or_async import SyncOrAsync @@ -40,7 +41,7 @@ class ConnectionsListFilters(ListArgs, total=False): class SSOModule(Protocol): - _http_client: HTTPClient + client_configuration: ClientConfiguration def get_authorization_url( self, @@ -70,7 +71,7 @@ def get_authorization_url( str: URL to redirect a User to to begin the OAuth workflow with WorkOS """ params: QueryParameters = { - "client_id": self._http_client.client_id, + "client_id": self.client_configuration.client_id, "redirect_uri": redirect_uri, "response_type": RESPONSE_TYPE_CODE, } @@ -94,10 +95,12 @@ def get_authorization_url( params["state"] = state return RequestHelper.build_url_with_query_params( - base_url=self._http_client.base_url, path=AUTHORIZATION_PATH, **params + base_url=self.client_configuration.base_url, + path=AUTHORIZATION_PATH, + **params, ) - def get_profile(self, accessToken: str) -> SyncOrAsync[Profile]: ... + def get_profile(self, access_token: str) -> SyncOrAsync[Profile]: ... def get_profile_and_token(self, code: str) -> SyncOrAsync[ProfileAndToken]: ... @@ -117,7 +120,7 @@ def list_connections( order: PaginationOrder = "desc", ) -> SyncOrAsync[ConnectionsListResource]: ... - def delete_connection(self, connection: str) -> SyncOrAsync[None]: ... + def delete_connection(self, connection_id: str) -> SyncOrAsync[None]: ... class SSO(SSOModule): @@ -126,6 +129,7 @@ class SSO(SSOModule): _http_client: SyncHTTPClient def __init__(self, http_client: SyncHTTPClient): + self.client_configuration = http_client self._http_client = http_client def get_profile(self, access_token: str) -> Profile: @@ -138,7 +142,12 @@ def get_profile(self, access_token: str) -> Profile: Returns: Profile """ - response = self._http_client.request(PROFILE_PATH, method=REQUEST_METHOD_GET) + response = self._http_client.request( + PROFILE_PATH, + method=REQUEST_METHOD_GET, + headers={**self._http_client.auth_header_from_token(access_token)}, + exclude_default_auth_headers=True, + ) return Profile.model_validate(response) @@ -249,6 +258,7 @@ class AsyncSSO(SSOModule): _http_client: AsyncHTTPClient def __init__(self, http_client: AsyncHTTPClient): + self.client_configuration = http_client self._http_client = http_client async def get_profile(self, access_token: str) -> Profile: diff --git a/workos/utils/http_client.py b/workos/utils/http_client.py index 9a23bcc0..bc8dd426 100644 --- a/workos/utils/http_client.py +++ b/workos/utils/http_client.py @@ -85,6 +85,7 @@ def request( params: ParamsType = None, json: JsonType = None, headers: HeadersType = None, + exclude_default_auth_headers: bool = False, ) -> ResponseJson: """Executes a request against the WorkOS API. @@ -100,7 +101,12 @@ def request( ResponseJson: Response from WorkOS """ prepared_request_parameters = self._prepare_request( - path=path, method=method, params=params, json=json, headers=headers + path=path, + method=method, + params=params, + json=json, + headers=headers, + exclude_default_auth_headers=exclude_default_auth_headers, ) response = self._client.request(**prepared_request_parameters) return self._handle_response(response) From 3d853a24c2c2f895e9bd275e878a0290bd72be1e Mon Sep 17 00:00:00 2001 From: Peter Arzhintar Date: Mon, 12 Aug 2024 16:28:04 -0700 Subject: [PATCH 04/11] directory sync params match protocol --- workos/directory_sync.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/workos/directory_sync.py b/workos/directory_sync.py index b681dceb..2f4245ee 100644 --- a/workos/directory_sync.py +++ b/workos/directory_sync.py @@ -60,11 +60,11 @@ def list_groups( order: PaginationOrder = "desc", ) -> SyncOrAsync[DirectoryGroupsListResource]: ... - def get_user(self, user: str) -> SyncOrAsync[DirectoryUserWithGroups]: ... + def get_user(self, user_id: str) -> SyncOrAsync[DirectoryUserWithGroups]: ... - def get_group(self, group: str) -> SyncOrAsync[DirectoryGroup]: ... + def get_group(self, group_id: str) -> SyncOrAsync[DirectoryGroup]: ... - def get_directory(self, directory: str) -> SyncOrAsync[Directory]: ... + def get_directory(self, directory_id: str) -> SyncOrAsync[Directory]: ... def list_directories( self, @@ -77,7 +77,7 @@ def list_directories( order: PaginationOrder = "desc", ) -> SyncOrAsync[DirectoriesListResource]: ... - def delete_directory(self, directory: str) -> SyncOrAsync[None]: ... + def delete_directory(self, directory_id: str) -> SyncOrAsync[None]: ... class DirectorySync(DirectorySyncModule): From 045fe756a359521d8ba36c094f4db0846583eaa7 Mon Sep 17 00:00:00 2001 From: Peter Arzhintar Date: Mon, 12 Aug 2024 16:28:46 -0700 Subject: [PATCH 05/11] Do not shadow http client attribute, extract a protocol for client config --- workos/_base_client.py | 23 +++++++----------- workos/async_client.py | 20 ++++++++++++---- workos/client.py | 19 +++++++++++---- workos/sso.py | 18 ++++++++------ workos/user_management.py | 49 ++++++++++++++++++++++++++++++--------- 5 files changed, 86 insertions(+), 43 deletions(-) diff --git a/workos/_base_client.py b/workos/_base_client.py index 0bb20f3b..da3e04e2 100644 --- a/workos/_base_client.py +++ b/workos/_base_client.py @@ -1,6 +1,6 @@ from abc import abstractmethod import os -from typing import Generic, Optional, Protocol, Type, TypeVar +from typing import Optional, Protocol from workos.__about__ import __version__ from workos.fga import FGAModule @@ -18,24 +18,22 @@ from workos.webhooks import WebhooksModule -HTTPClientType = TypeVar("HTTPClientType", bound=HTTPClient) - - class ClientConfiguration(Protocol): @property def base_url(self) -> str: ... @property def client_id(self) -> str: ... + @property + def request_timeout(self) -> int: ... -class BaseClient(Generic[HTTPClientType], ClientConfiguration): +class BaseClient(ClientConfiguration): """Base client for accessing the WorkOS feature set.""" _api_key: str _base_url: str _client_id: str _request_timeout: int - _http_client: HTTPClient def __init__( self, @@ -44,7 +42,6 @@ def __init__( client_id: Optional[str], base_url: Optional[str] = None, request_timeout: Optional[int] = None, - http_client_cls: Type[HTTPClientType], ) -> None: api_key = api_key or os.getenv("WORKOS_API_KEY") if api_key is None: @@ -73,14 +70,6 @@ def __init__( else int(os.getenv("WORKOS_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)) ) - self._http_client = http_client_cls( - api_key=self._api_key, - base_url=self._base_url, - client_id=self._client_id, - version=__version__, - timeout=self._request_timeout, - ) - @property @abstractmethod def audit_logs(self) -> AuditLogsModule: ... @@ -132,3 +121,7 @@ def base_url(self) -> str: @property def client_id(self) -> str: return self._client_id + + @property + def request_timeout(self) -> int: + return self._request_timeout diff --git a/workos/async_client.py b/workos/async_client.py index 4be138a3..93f756d9 100644 --- a/workos/async_client.py +++ b/workos/async_client.py @@ -1,5 +1,5 @@ from typing import Optional - +from workos.__about__ import __version__ from workos._base_client import BaseClient from workos.audit_logs import AuditLogsModule from workos.directory_sync import AsyncDirectorySync @@ -15,7 +15,7 @@ from workos.webhooks import WebhooksModule -class AsyncClient(BaseClient[AsyncHTTPClient]): +class AsyncClient(BaseClient): """Client for a convenient way to access the WorkOS feature set.""" _http_client: AsyncHTTPClient @@ -33,13 +33,21 @@ def __init__( client_id=client_id, base_url=base_url, request_timeout=request_timeout, - http_client_cls=AsyncHTTPClient, + ) + self._http_client = AsyncHTTPClient( + api_key=self._api_key, + base_url=self._base_url, + client_id=self._client_id, + version=__version__, + timeout=self.request_timeout, ) @property def sso(self) -> AsyncSSO: if not getattr(self, "_sso", None): - self._sso = AsyncSSO(self._http_client) + self._sso = AsyncSSO( + http_client=self._http_client, client_configuration=self + ) return self._sso @property @@ -93,5 +101,7 @@ def mfa(self) -> MFAModule: @property def user_management(self) -> AsyncUserManagement: if not getattr(self, "_user_management", None): - self._user_management = AsyncUserManagement(self._http_client) + self._user_management = AsyncUserManagement( + http_client=self._http_client, client_configuration=self + ) return self._user_management diff --git a/workos/client.py b/workos/client.py index 16f00c0f..e07f6db6 100644 --- a/workos/client.py +++ b/workos/client.py @@ -1,5 +1,6 @@ +from calendar import c from typing import Optional - +from workos.__about__ import __version__ from workos._base_client import BaseClient from workos.audit_logs import AuditLogs from workos.directory_sync import DirectorySync @@ -15,7 +16,7 @@ from workos.utils.http_client import SyncHTTPClient -class SyncClient(BaseClient[SyncHTTPClient]): +class SyncClient(BaseClient): """Client for a convenient way to access the WorkOS feature set.""" _http_client: SyncHTTPClient @@ -33,13 +34,19 @@ def __init__( client_id=client_id, base_url=base_url, request_timeout=request_timeout, - http_client_cls=SyncHTTPClient, + ) + self._http_client = SyncHTTPClient( + api_key=self._api_key, + base_url=self._base_url, + client_id=self._client_id, + version=__version__, + timeout=self.request_timeout, ) @property def sso(self) -> SSO: if not getattr(self, "_sso", None): - self._sso = SSO(self._http_client) + self._sso = SSO(http_client=self._http_client, client_configuration=self) return self._sso @property @@ -99,5 +106,7 @@ def mfa(self) -> Mfa: @property def user_management(self) -> UserManagement: if not getattr(self, "_user_management", None): - self._user_management = UserManagement(self._http_client) + self._user_management = UserManagement( + http_client=self._http_client, client_configuration=self + ) return self._user_management diff --git a/workos/sso.py b/workos/sso.py index c4ba7f78..096432a1 100644 --- a/workos/sso.py +++ b/workos/sso.py @@ -41,7 +41,7 @@ class ConnectionsListFilters(ListArgs, total=False): class SSOModule(Protocol): - client_configuration: ClientConfiguration + _client_configuration: ClientConfiguration def get_authorization_url( self, @@ -71,7 +71,7 @@ def get_authorization_url( str: URL to redirect a User to to begin the OAuth workflow with WorkOS """ params: QueryParameters = { - "client_id": self.client_configuration.client_id, + "client_id": self._client_configuration.client_id, "redirect_uri": redirect_uri, "response_type": RESPONSE_TYPE_CODE, } @@ -95,7 +95,7 @@ def get_authorization_url( params["state"] = state return RequestHelper.build_url_with_query_params( - base_url=self.client_configuration.base_url, + base_url=self._client_configuration.base_url, path=AUTHORIZATION_PATH, **params, ) @@ -128,8 +128,10 @@ class SSO(SSOModule): _http_client: SyncHTTPClient - def __init__(self, http_client: SyncHTTPClient): - self.client_configuration = http_client + def __init__( + self, http_client: SyncHTTPClient, client_configuration: ClientConfiguration + ): + self._client_configuration = client_configuration self._http_client = http_client def get_profile(self, access_token: str) -> Profile: @@ -257,8 +259,10 @@ class AsyncSSO(SSOModule): _http_client: AsyncHTTPClient - def __init__(self, http_client: AsyncHTTPClient): - self.client_configuration = http_client + def __init__( + self, http_client: AsyncHTTPClient, client_configuration: ClientConfiguration + ): + self._client_configuration = client_configuration self._http_client = http_client async def get_profile(self, access_token: str) -> Profile: diff --git a/workos/user_management.py b/workos/user_management.py index ffe2cfd6..d939f65a 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -1,4 +1,7 @@ from typing import Optional, Protocol, Set + +from httpx._client import Client +from workos._base_client import ClientConfiguration from workos.types.list_resource import ( ListArgs, ListMetadata, @@ -100,7 +103,7 @@ class UserManagementModule(Protocol): - _http_client: HTTPClient + _client_configuration: ClientConfiguration def get_user(self, user_id: str) -> SyncOrAsync[User]: ... @@ -215,7 +218,7 @@ def get_authorization_url( str: URL to redirect a User to to begin the OAuth workflow with WorkOS """ params: QueryParameters = { - "client_id": self._http_client.client_id, + "client_id": self._client_configuration.client_id, "redirect_uri": redirect_uri, "response_type": RESPONSE_TYPE_CODE, } @@ -242,7 +245,9 @@ def get_authorization_url( params["code_challenge_method"] = "S256" return RequestHelper.build_url_with_query_params( - base_url=self._http_client.base_url, path=USER_AUTHORIZATION_PATH, **params + base_url=self._client_configuration.base_url, + path=USER_AUTHORIZATION_PATH, + **params, ) def _authenticate_with( @@ -321,7 +326,7 @@ def get_jwks_url(self) -> str: (str): The public JWKS URL. """ - return f"{self._http_client.base_url}sso/jwks/{self._http_client.client_id}" + return f"{self._client_configuration.base_url}sso/jwks/{self._client_configuration.client_id}" def get_logout_url(self, session_id: str) -> str: """Get the URL for ending the session and redirecting the user @@ -333,7 +338,7 @@ def get_logout_url(self, session_id: str) -> str: (str): URL to redirect the user to to end the session. """ - return f"{self._http_client.base_url}user_management/sessions/logout?session_id={session_id}" + return f"{self._client_configuration.base_url}user_management/sessions/logout?session_id={session_id}" def get_password_reset( self, password_reset_id: str @@ -412,7 +417,10 @@ class UserManagement(UserManagementModule): _http_client: SyncHTTPClient - def __init__(self, http_client: SyncHTTPClient): + def __init__( + self, http_client: SyncHTTPClient, client_configuration: ClientConfiguration + ): + self._client_configuration = client_configuration self._http_client = http_client def get_user(self, user_id: str) -> User: @@ -1344,7 +1352,10 @@ class AsyncUserManagement(UserManagementModule): _http_client: AsyncHTTPClient - def __init__(self, http_client: AsyncHTTPClient): + def __init__( + self, http_client: AsyncHTTPClient, client_configuration: ClientConfiguration + ): + self._client_configuration = client_configuration self._http_client = http_client async def get_user(self, user_id: str) -> User: @@ -1363,6 +1374,7 @@ async def get_user(self, user_id: str) -> User: async def list_users( self, + *, email: Optional[str] = None, organization_id: Optional[str] = None, limit: int = DEFAULT_LIST_RESPONSE_LIMIT, @@ -1405,6 +1417,7 @@ async def list_users( async def create_user( self, + *, email: str, password: Optional[str] = None, password_hash: Optional[str] = None, @@ -1445,6 +1458,7 @@ async def create_user( async def update_user( self, + *, user_id: str, first_name: Optional[str] = None, last_name: Optional[str] = None, @@ -1493,7 +1507,7 @@ async def delete_user(self, user_id: str) -> None: ) async def create_organization_membership( - self, user_id: str, organization_id: str, role_slug: Optional[str] = None + self, *, user_id: str, organization_id: str, role_slug: Optional[str] = None ) -> OrganizationMembership: """Create a new OrganizationMembership for the given Organization and User. @@ -1520,7 +1534,7 @@ async def create_organization_membership( return OrganizationMembership.model_validate(response) async def update_organization_membership( - self, organization_membership_id: str, role_slug: Optional[str] = None + self, *, organization_membership_id: str, role_slug: Optional[str] = None ) -> OrganizationMembership: """Updates an OrganizationMembership for the given id. @@ -1565,6 +1579,7 @@ async def get_organization_membership( async def list_organization_memberships( self, + *, user_id: Optional[str] = None, organization_id: Optional[str] = None, statuses: Optional[Set[OrganizationMembershipStatus]] = None, @@ -1674,6 +1689,7 @@ async def _authenticate_with( async def authenticate_with_password( self, + *, email: str, password: str, ip_address: Optional[str] = None, @@ -1703,6 +1719,7 @@ async def authenticate_with_password( async def authenticate_with_code( self, + *, code: str, code_verifier: Optional[str] = None, ip_address: Optional[str] = None, @@ -1735,6 +1752,7 @@ async def authenticate_with_code( async def authenticate_with_magic_auth( self, + *, code: str, email: str, link_authorization_code: Optional[str] = None, @@ -1767,6 +1785,7 @@ async def authenticate_with_magic_auth( async def authenticate_with_email_verification( self, + *, code: str, pending_authentication_token: str, ip_address: Optional[str] = None, @@ -1796,6 +1815,7 @@ async def authenticate_with_email_verification( async def authenticate_with_totp( self, + *, code: str, authentication_challenge_id: str, pending_authentication_token: str, @@ -1828,6 +1848,7 @@ async def authenticate_with_totp( async def authenticate_with_organization_selection( self, + *, organization_id: str, pending_authentication_token: str, ip_address: Optional[str] = None, @@ -1857,6 +1878,7 @@ async def authenticate_with_organization_selection( async def authenticate_with_refresh_token( self, + *, refresh_token: str, organization_id: Optional[str] = None, ip_address: Optional[str] = None, @@ -1929,7 +1951,7 @@ async def create_password_reset(self, email: str) -> PasswordReset: return PasswordReset.model_validate(response) - async def reset_password(self, token: str, new_password: str) -> User: + async def reset_password(self, *, token: str, new_password: str) -> User: """Resets user password using token that was sent to the user. Kwargs: @@ -1987,7 +2009,7 @@ async def send_verification_email(self, user_id: str) -> User: return User.model_validate(response["user"]) - async def verify_email(self, user_id: str, code: str) -> User: + async def verify_email(self, *, user_id: str, code: str) -> User: """Verifies user email using one-time code that was sent to the user. Kwargs: @@ -2028,6 +2050,7 @@ async def get_magic_auth(self, magic_auth_id: str) -> MagicAuth: async def create_magic_auth( self, + *, email: str, invitation_token: Optional[str] = None, ) -> MagicAuth: @@ -2054,6 +2077,7 @@ async def create_magic_auth( async def enroll_auth_factor( self, + *, user_id: str, type: AuthenticationFactorType, totp_issuer: Optional[str] = None, @@ -2089,6 +2113,7 @@ async def enroll_auth_factor( async def list_auth_factors( self, + *, user_id: str, limit: int = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, @@ -2167,6 +2192,7 @@ async def find_invitation_by_token(self, invitation_token: str) -> Invitation: async def list_invitations( self, + *, email: Optional[str] = None, organization_id: Optional[str] = None, limit: int = DEFAULT_LIST_RESPONSE_LIMIT, @@ -2209,6 +2235,7 @@ async def list_invitations( async def send_invitation( self, + *, email: str, organization_id: Optional[str] = None, expires_in_days: Optional[int] = None, From 12e77936a64d0419b4f767a8b6f020c93d931db3 Mon Sep 17 00:00:00 2001 From: Peter Arzhintar Date: Mon, 12 Aug 2024 16:34:25 -0700 Subject: [PATCH 06/11] Remove unused imports --- workos/client.py | 1 - workos/user_management.py | 1 - 2 files changed, 2 deletions(-) diff --git a/workos/client.py b/workos/client.py index e07f6db6..1781c673 100644 --- a/workos/client.py +++ b/workos/client.py @@ -1,4 +1,3 @@ -from calendar import c from typing import Optional from workos.__about__ import __version__ from workos._base_client import BaseClient diff --git a/workos/user_management.py b/workos/user_management.py index d939f65a..5f54dec4 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -1,6 +1,5 @@ from typing import Optional, Protocol, Set -from httpx._client import Client from workos._base_client import ClientConfiguration from workos.types.list_resource import ( ListArgs, From 6d5e159cc5c41623364f342e524145f25718ac20 Mon Sep 17 00:00:00 2001 From: Peter Arzhintar Date: Tue, 13 Aug 2024 09:51:03 -0700 Subject: [PATCH 07/11] Move enforcement of trailing slash in base URL --- workos/_base_client.py | 15 ++++++++++++++- workos/async_client.py | 2 +- workos/client.py | 2 +- workos/utils/_base_http_client.py | 14 +------------- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/workos/_base_client.py b/workos/_base_client.py index da3e04e2..d38102f8 100644 --- a/workos/_base_client.py +++ b/workos/_base_client.py @@ -59,11 +59,12 @@ def __init__( self._client_id = client_id - self._base_url = ( + self.base_url = ( base_url if base_url else os.getenv("WORKOS_BASE_URL", "https://api.workos.com/") ) + self._request_timeout = ( request_timeout if request_timeout @@ -114,10 +115,22 @@ def user_management(self) -> UserManagementModule: ... @abstractmethod def webhooks(self) -> WebhooksModule: ... + def _enforce_trailing_slash(self, url: str) -> str: + return url if url.endswith("/") else url + "/" + @property def base_url(self) -> str: return self._base_url + @base_url.setter + def base_url(self, url: str) -> None: + """Creates an accessible template for constructing the URL for an API request. + + Args: + base_api_url (str): Base URL for api requests + """ + self._base_url = "{}{{}}".format(self._enforce_trailing_slash(url)) + @property def client_id(self) -> str: return self._client_id diff --git a/workos/async_client.py b/workos/async_client.py index 93f756d9..5908b348 100644 --- a/workos/async_client.py +++ b/workos/async_client.py @@ -36,7 +36,7 @@ def __init__( ) self._http_client = AsyncHTTPClient( api_key=self._api_key, - base_url=self._base_url, + base_url=self.base_url, client_id=self._client_id, version=__version__, timeout=self.request_timeout, diff --git a/workos/client.py b/workos/client.py index 1781c673..e51e167f 100644 --- a/workos/client.py +++ b/workos/client.py @@ -36,7 +36,7 @@ def __init__( ) self._http_client = SyncHTTPClient( api_key=self._api_key, - base_url=self._base_url, + base_url=self.base_url, client_id=self._client_id, version=__version__, timeout=self.request_timeout, diff --git a/workos/utils/_base_http_client.py b/workos/utils/_base_http_client.py index 06453c66..f75ea66f 100644 --- a/workos/utils/_base_http_client.py +++ b/workos/utils/_base_http_client.py @@ -55,14 +55,11 @@ def __init__( timeout: Optional[int] = DEFAULT_REQUEST_TIMEOUT, ) -> None: self._api_key = api_key - self.base_url = base_url + self._base_url = base_url self._client_id = client_id self._version = version self._timeout = DEFAULT_REQUEST_TIMEOUT if timeout is None else timeout - def _enforce_trailing_slash(self, url: str) -> str: - return url if url.endswith("/") else url + "/" - def _generate_api_url(self, path: str) -> str: return self._base_url.format(path) @@ -206,15 +203,6 @@ def api_key(self) -> str: def base_url(self) -> str: return self._base_url - @base_url.setter - def base_url(self, url: str) -> None: - """Creates an accessible template for constructing the URL for an API request. - - Args: - base_api_url (str): Base URL for api requests - """ - self._base_url = "{}{{}}".format(self._enforce_trailing_slash(url)) - @property def client_id(self) -> str: return self._client_id From 1074f57f14e16678061bd301f09658b640f439b7 Mon Sep 17 00:00:00 2001 From: Peter Arzhintar Date: Tue, 13 Aug 2024 10:43:17 -0700 Subject: [PATCH 08/11] Extract client config protocol and fixup tests --- tests/conftest.py | 5 +++-- tests/test_async_http_client.py | 2 +- tests/test_sso.py | 22 ++++++++++++++++--- tests/test_sync_http_client.py | 6 ++---- tests/test_user_management.py | 28 ++++++++++++++++++++---- tests/utils/client_configuration.py | 33 +++++++++++++++++++++++++++++ workos/_base_client.py | 32 +++++++--------------------- workos/_client_configuration.py | 10 +++++++++ workos/sso.py | 2 +- workos/user_management.py | 3 +-- workos/utils/_base_http_client.py | 3 ++- 11 files changed, 104 insertions(+), 42 deletions(-) create mode 100644 tests/utils/client_configuration.py create mode 100644 workos/_client_configuration.py diff --git a/tests/conftest.py b/tests/conftest.py index 5591f638..1c22d0a1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import httpx import pytest +from tests.utils.client_configuration import ClientConfiguration from tests.utils.list_resource import list_data_to_dicts, list_response_of from workos.types.list_resource import WorkOsListResource from workos.utils.http_client import AsyncHTTPClient, HTTPClient, SyncHTTPClient @@ -13,7 +14,7 @@ def sync_http_client_for_test(): return SyncHTTPClient( api_key="sk_test", - base_url="https://api.workos.test", + base_url="https://api.workos.test/", client_id="client_b27needthisforssotemxo", version="test", ) @@ -23,7 +24,7 @@ def sync_http_client_for_test(): def async_http_client_for_test(): return AsyncHTTPClient( api_key="sk_test", - base_url="https://api.workos.test", + base_url="https://api.workos.test/", client_id="client_b27needthisforssotemxo", version="test", ) diff --git a/tests/test_async_http_client.py b/tests/test_async_http_client.py index 78cf9e6c..6d2c01dd 100644 --- a/tests/test_async_http_client.py +++ b/tests/test_async_http_client.py @@ -20,7 +20,7 @@ def handler(request: httpx.Request) -> httpx.Response: self.http_client = AsyncHTTPClient( api_key="sk_test", - base_url="https://api.workos.test", + base_url="https://api.workos.test/", client_id="client_b27needthisforssotemxo", version="test", transport=httpx.MockTransport(handler), diff --git a/tests/test_sso.py b/tests/test_sso.py index 1182cbe1..187426f4 100644 --- a/tests/test_sso.py +++ b/tests/test_sso.py @@ -1,6 +1,7 @@ import json from six.moves.urllib.parse import parse_qsl, urlparse import pytest +from tests.utils.client_configuration import client_configuration_for_http_client from tests.utils.fixtures.mock_profile import MockProfile from tests.utils.list_resource import list_data_to_dicts, list_response_of from tests.utils.fixtures.mock_connection import MockConnection @@ -51,7 +52,12 @@ class TestSSOBase(SSOFixtures): @pytest.fixture(autouse=True) def setup(self, sync_http_client_for_test): self.http_client = sync_http_client_for_test - self.sso = SSO(http_client=self.http_client) + self.sso = SSO( + http_client=self.http_client, + client_configuration=client_configuration_for_http_client( + sync_http_client_for_test + ), + ) self.provider = "GoogleOAuth" self.customer_domain = "workos.com" self.login_hint = "foo@workos.com" @@ -214,7 +220,12 @@ class TestSSO(SSOFixtures): @pytest.fixture(autouse=True) def setup(self, sync_http_client_for_test): self.http_client = sync_http_client_for_test - self.sso = SSO(http_client=self.http_client) + self.sso = SSO( + http_client=self.http_client, + client_configuration=client_configuration_for_http_client( + sync_http_client_for_test + ), + ) self.provider = "GoogleOAuth" self.customer_domain = "workos.com" self.login_hint = "foo@workos.com" @@ -333,7 +344,12 @@ class TestAsyncSSO(SSOFixtures): @pytest.fixture(autouse=True) def setup(self, async_http_client_for_test): self.http_client = async_http_client_for_test - self.sso = AsyncSSO(http_client=self.http_client) + self.sso = AsyncSSO( + http_client=self.http_client, + client_configuration=client_configuration_for_http_client( + async_http_client_for_test + ), + ) self.provider = "GoogleOAuth" self.customer_domain = "workos.com" self.login_hint = "foo@workos.com" diff --git a/tests/test_sync_http_client.py b/tests/test_sync_http_client.py index 1a09cfeb..bb58f144 100644 --- a/tests/test_sync_http_client.py +++ b/tests/test_sync_http_client.py @@ -32,7 +32,7 @@ def handler(request: httpx.Request) -> httpx.Response: self.http_client = SyncHTTPClient( api_key="sk_test", - base_url="https://api.workos.test", + base_url="https://api.workos.test/", client_id="client_b27needthisforssotemxo", version="test", transport=httpx.MockTransport(handler), @@ -63,7 +63,6 @@ def test_request_without_body( "events", method=method, params={"test_param": "test_value"}, - token="test", ) self.http_client._client.request.assert_called_with( @@ -101,7 +100,7 @@ def test_request_with_body( ) response = self.http_client.request( - "events", method=method, json={"test_param": "test_value"}, token="test" + "events", method=method, json={"test_param": "test_value"} ) self.http_client._client.request.assert_called_with( @@ -144,7 +143,6 @@ def test_request_with_body_and_query_parameters( method=method, params={"test_param": "test_param_value"}, json={"test_json": "test_json_value"}, - token="test", ) self.http_client._client.request.assert_called_with( diff --git a/tests/test_user_management.py b/tests/test_user_management.py index de7a82a0..2817605c 100644 --- a/tests/test_user_management.py +++ b/tests/test_user_management.py @@ -1,4 +1,5 @@ import json +from os import sync from six.moves.urllib.parse import parse_qsl, urlparse import pytest @@ -11,6 +12,10 @@ from tests.utils.fixtures.mock_password_reset import MockPasswordReset from tests.utils.fixtures.mock_user import MockUser from tests.utils.list_resource import list_data_to_dicts, list_response_of +from tests.utils.client_configuration import ( + ClientConfiguration, + client_configuration_for_http_client, +) from workos.user_management import AsyncUserManagement, UserManagement from workos.utils.request_helper import RESPONSE_TYPE_CODE @@ -144,7 +149,12 @@ class TestUserManagementBase(UserManagementFixtures): @pytest.fixture(autouse=True) def setup(self, sync_http_client_for_test): self.http_client = sync_http_client_for_test - self.user_management = UserManagement(http_client=self.http_client) + self.user_management = UserManagement( + http_client=self.http_client, + client_configuration=client_configuration_for_http_client( + sync_http_client_for_test + ), + ) def test_authorization_url_throws_value_error_with_missing_connection_organization_and_provider( self, @@ -311,7 +321,12 @@ class TestUserManagement(UserManagementFixtures): @pytest.fixture(autouse=True) def setup(self, sync_http_client_for_test): self.http_client = sync_http_client_for_test - self.user_management = UserManagement(http_client=self.http_client) + self.user_management = UserManagement( + http_client=self.http_client, + client_configuration=client_configuration_for_http_client( + sync_http_client_for_test + ), + ) def test_get_user(self, mock_user, capture_and_mock_http_client_request): request_kwargs = capture_and_mock_http_client_request( @@ -946,7 +961,12 @@ class TestAsyncUserManagement(UserManagementFixtures): @pytest.fixture(autouse=True) def setup(self, async_http_client_for_test): self.http_client = async_http_client_for_test - self.user_management = AsyncUserManagement(http_client=self.http_client) + self.user_management = AsyncUserManagement( + http_client=self.http_client, + client_configuration=client_configuration_for_http_client( + async_http_client_for_test + ), + ) async def test_get_user(self, mock_user, capture_and_mock_http_client_request): request_kwargs = capture_and_mock_http_client_request( @@ -1005,7 +1025,7 @@ async def test_update_user(self, mock_user, capture_and_mock_http_client_request "password": "password", } user = await self.user_management.update_user( - "user_01H7ZGXFP5C6BBQY6Z7277ZCT0", **params + user_id="user_01H7ZGXFP5C6BBQY6Z7277ZCT0", **params ) assert request_kwargs["url"].endswith("users/user_01H7ZGXFP5C6BBQY6Z7277ZCT0") diff --git a/tests/utils/client_configuration.py b/tests/utils/client_configuration.py new file mode 100644 index 00000000..78cfac19 --- /dev/null +++ b/tests/utils/client_configuration.py @@ -0,0 +1,33 @@ +from workos._client_configuration import ( + ClientConfiguration as ClientConfigurationProtocol, +) +from workos.utils._base_http_client import BaseHTTPClient + + +class ClientConfiguration(ClientConfigurationProtocol): + def __init__(self, base_url: str, client_id: str, request_timeout: int): + self._base_url = base_url + self._client_id = client_id + self._request_timeout = request_timeout + + @property + def base_url(self) -> str: + return self._base_url + + @property + def client_id(self) -> str: + return self._client_id + + @property + def request_timeout(self) -> int: + return self._request_timeout + + +def client_configuration_for_http_client( + http_client: BaseHTTPClient, +) -> ClientConfiguration: + return ClientConfiguration( + base_url=http_client.base_url, + client_id=http_client.client_id, + request_timeout=http_client.timeout, + ) diff --git a/workos/_base_client.py b/workos/_base_client.py index d38102f8..41f31a66 100644 --- a/workos/_base_client.py +++ b/workos/_base_client.py @@ -1,8 +1,8 @@ from abc import abstractmethod import os -from typing import Optional, Protocol - +from typing import Optional from workos.__about__ import __version__ +from workos._client_configuration import ClientConfiguration from workos.fga import FGAModule from workos.utils._base_http_client import DEFAULT_REQUEST_TIMEOUT from workos.utils.http_client import HTTPClient @@ -18,15 +18,6 @@ from workos.webhooks import WebhooksModule -class ClientConfiguration(Protocol): - @property - def base_url(self) -> str: ... - @property - def client_id(self) -> str: ... - @property - def request_timeout(self) -> int: ... - - class BaseClient(ClientConfiguration): """Base client for accessing the WorkOS feature set.""" @@ -59,10 +50,12 @@ def __init__( self._client_id = client_id - self.base_url = ( - base_url - if base_url - else os.getenv("WORKOS_BASE_URL", "https://api.workos.com/") + self._base_url = self._enforce_trailing_slash( + url=( + base_url + if base_url + else os.getenv("WORKOS_BASE_URL", "https://api.workos.com/") + ) ) self._request_timeout = ( @@ -122,15 +115,6 @@ def _enforce_trailing_slash(self, url: str) -> str: def base_url(self) -> str: return self._base_url - @base_url.setter - def base_url(self, url: str) -> None: - """Creates an accessible template for constructing the URL for an API request. - - Args: - base_api_url (str): Base URL for api requests - """ - self._base_url = "{}{{}}".format(self._enforce_trailing_slash(url)) - @property def client_id(self) -> str: return self._client_id diff --git a/workos/_client_configuration.py b/workos/_client_configuration.py new file mode 100644 index 00000000..c682f83e --- /dev/null +++ b/workos/_client_configuration.py @@ -0,0 +1,10 @@ +from typing import Protocol + + +class ClientConfiguration(Protocol): + @property + def base_url(self) -> str: ... + @property + def client_id(self) -> str: ... + @property + def request_timeout(self) -> int: ... diff --git a/workos/sso.py b/workos/sso.py index 096432a1..5757fec0 100644 --- a/workos/sso.py +++ b/workos/sso.py @@ -1,5 +1,5 @@ from typing import Optional, Protocol -from workos._base_client import ClientConfiguration +from workos._client_configuration import ClientConfiguration from workos.types.sso.connection import ConnectionType from workos.types.sso.sso_provider_type import SsoProviderType from workos.typing.sync_or_async import SyncOrAsync diff --git a/workos/user_management.py b/workos/user_management.py index 5f54dec4..2593ec7b 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -1,6 +1,5 @@ from typing import Optional, Protocol, Set - -from workos._base_client import ClientConfiguration +from workos._client_configuration import ClientConfiguration from workos.types.list_resource import ( ListArgs, ListMetadata, diff --git a/workos/utils/_base_http_client.py b/workos/utils/_base_http_client.py index f75ea66f..ff9c127f 100644 --- a/workos/utils/_base_http_client.py +++ b/workos/utils/_base_http_client.py @@ -55,7 +55,8 @@ def __init__( timeout: Optional[int] = DEFAULT_REQUEST_TIMEOUT, ) -> None: self._api_key = api_key - self._base_url = base_url + # Template for constructing the URL for an API request + self._base_url = "{}{{}}".format(base_url) self._client_id = client_id self._version = version self._timeout = DEFAULT_REQUEST_TIMEOUT if timeout is None else timeout From 18df781c071a333d2f80c488d05a9d1e511353fc Mon Sep 17 00:00:00 2001 From: Peter Arzhintar Date: Tue, 13 Aug 2024 10:47:35 -0700 Subject: [PATCH 09/11] Remove unused import --- tests/test_user_management.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_user_management.py b/tests/test_user_management.py index 2817605c..43cf81f0 100644 --- a/tests/test_user_management.py +++ b/tests/test_user_management.py @@ -13,7 +13,6 @@ from tests.utils.fixtures.mock_user import MockUser from tests.utils.list_resource import list_data_to_dicts, list_response_of from tests.utils.client_configuration import ( - ClientConfiguration, client_configuration_for_http_client, ) from workos.user_management import AsyncUserManagement, UserManagement From 6933d4586ceae7e8e8f1d734a6523ce81f026e41 Mon Sep 17 00:00:00 2001 From: Peter Arzhintar Date: Tue, 13 Aug 2024 10:52:37 -0700 Subject: [PATCH 10/11] Test trailing slash --- tests/test_client.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_client.py b/tests/test_client.py index bd813380..c3fef540 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,3 +1,4 @@ +from http import client import os import pytest from workos import AsyncWorkOSClient, WorkOSClient @@ -64,6 +65,14 @@ def test_initialize_portal(self, default_client): def test_initialize_user_management(self, default_client): assert bool(default_client.user_management) + def test_enforce_trailing_slash_for_base_url(self, default_client): + client = WorkOSClient( + api_key="sk_test", + client_id="client_b27needthisforssotemxo", + base_url="https://api.workos.com", + ) + assert client.base_url == "https://api.workos.com/" + class TestAsyncClient: @pytest.fixture From 13d145c2faad67d5528d07c8d89e04e5cad88ea6 Mon Sep 17 00:00:00 2001 From: Peter Arzhintar Date: Tue, 13 Aug 2024 10:54:51 -0700 Subject: [PATCH 11/11] Remove unused fixture --- tests/test_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_client.py b/tests/test_client.py index c3fef540..0e1e868e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -65,7 +65,9 @@ def test_initialize_portal(self, default_client): def test_initialize_user_management(self, default_client): assert bool(default_client.user_management) - def test_enforce_trailing_slash_for_base_url(self, default_client): + def test_enforce_trailing_slash_for_base_url( + self, + ): client = WorkOSClient( api_key="sk_test", client_id="client_b27needthisforssotemxo",