Skip to content

Commit

Permalink
Do not shadow http client attribute, extract a protocol for client co…
Browse files Browse the repository at this point in the history
…nfig
  • Loading branch information
tribble committed Aug 13, 2024
1 parent 3d853a2 commit 045fe75
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 43 deletions.
23 changes: 8 additions & 15 deletions workos/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import abstractmethod
import os
from typing import Generic, Optional, Protocol, Type, TypeVar
from typing import Optional, Protocol

from workos.__about__ import __version__
from workos.fga import FGAModule
Expand All @@ -18,24 +18,22 @@
from workos.webhooks import WebhooksModule


HTTPClientType = TypeVar("HTTPClientType", bound=HTTPClient)


class ClientConfiguration(Protocol):
@property
def base_url(self) -> str: ...
@property
def client_id(self) -> str: ...
@property
def request_timeout(self) -> int: ...


class BaseClient(Generic[HTTPClientType], ClientConfiguration):
class BaseClient(ClientConfiguration):
"""Base client for accessing the WorkOS feature set."""

_api_key: str
_base_url: str
_client_id: str
_request_timeout: int
_http_client: HTTPClient

def __init__(
self,
Expand All @@ -44,7 +42,6 @@ def __init__(
client_id: Optional[str],
base_url: Optional[str] = None,
request_timeout: Optional[int] = None,
http_client_cls: Type[HTTPClientType],
) -> None:
api_key = api_key or os.getenv("WORKOS_API_KEY")
if api_key is None:
Expand Down Expand Up @@ -73,14 +70,6 @@ def __init__(
else int(os.getenv("WORKOS_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT))
)

self._http_client = http_client_cls(
api_key=self._api_key,
base_url=self._base_url,
client_id=self._client_id,
version=__version__,
timeout=self._request_timeout,
)

@property
@abstractmethod
def audit_logs(self) -> AuditLogsModule: ...
Expand Down Expand Up @@ -132,3 +121,7 @@ def base_url(self) -> str:
@property
def client_id(self) -> str:
return self._client_id

@property
def request_timeout(self) -> int:
return self._request_timeout
20 changes: 15 additions & 5 deletions workos/async_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Optional

from workos.__about__ import __version__
from workos._base_client import BaseClient
from workos.audit_logs import AuditLogsModule
from workos.directory_sync import AsyncDirectorySync
Expand All @@ -15,7 +15,7 @@
from workos.webhooks import WebhooksModule


class AsyncClient(BaseClient[AsyncHTTPClient]):
class AsyncClient(BaseClient):
"""Client for a convenient way to access the WorkOS feature set."""

_http_client: AsyncHTTPClient
Expand All @@ -33,13 +33,21 @@ def __init__(
client_id=client_id,
base_url=base_url,
request_timeout=request_timeout,
http_client_cls=AsyncHTTPClient,
)
self._http_client = AsyncHTTPClient(
api_key=self._api_key,
base_url=self._base_url,
client_id=self._client_id,
version=__version__,
timeout=self.request_timeout,
)

@property
def sso(self) -> AsyncSSO:
if not getattr(self, "_sso", None):
self._sso = AsyncSSO(self._http_client)
self._sso = AsyncSSO(
http_client=self._http_client, client_configuration=self
)
return self._sso

@property
Expand Down Expand Up @@ -93,5 +101,7 @@ def mfa(self) -> MFAModule:
@property
def user_management(self) -> AsyncUserManagement:
if not getattr(self, "_user_management", None):
self._user_management = AsyncUserManagement(self._http_client)
self._user_management = AsyncUserManagement(
http_client=self._http_client, client_configuration=self
)
return self._user_management
19 changes: 14 additions & 5 deletions workos/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from calendar import c
from typing import Optional

from workos.__about__ import __version__
from workos._base_client import BaseClient
from workos.audit_logs import AuditLogs
from workos.directory_sync import DirectorySync
Expand All @@ -15,7 +16,7 @@
from workos.utils.http_client import SyncHTTPClient


class SyncClient(BaseClient[SyncHTTPClient]):
class SyncClient(BaseClient):
"""Client for a convenient way to access the WorkOS feature set."""

_http_client: SyncHTTPClient
Expand All @@ -33,13 +34,19 @@ def __init__(
client_id=client_id,
base_url=base_url,
request_timeout=request_timeout,
http_client_cls=SyncHTTPClient,
)
self._http_client = SyncHTTPClient(
api_key=self._api_key,
base_url=self._base_url,
client_id=self._client_id,
version=__version__,
timeout=self.request_timeout,
)

@property
def sso(self) -> SSO:
if not getattr(self, "_sso", None):
self._sso = SSO(self._http_client)
self._sso = SSO(http_client=self._http_client, client_configuration=self)
return self._sso

@property
Expand Down Expand Up @@ -99,5 +106,7 @@ def mfa(self) -> Mfa:
@property
def user_management(self) -> UserManagement:
if not getattr(self, "_user_management", None):
self._user_management = UserManagement(self._http_client)
self._user_management = UserManagement(
http_client=self._http_client, client_configuration=self
)
return self._user_management
18 changes: 11 additions & 7 deletions workos/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ConnectionsListFilters(ListArgs, total=False):


class SSOModule(Protocol):
client_configuration: ClientConfiguration
_client_configuration: ClientConfiguration

def get_authorization_url(
self,
Expand Down Expand Up @@ -71,7 +71,7 @@ def get_authorization_url(
str: URL to redirect a User to to begin the OAuth workflow with WorkOS
"""
params: QueryParameters = {
"client_id": self.client_configuration.client_id,
"client_id": self._client_configuration.client_id,
"redirect_uri": redirect_uri,
"response_type": RESPONSE_TYPE_CODE,
}
Expand All @@ -95,7 +95,7 @@ def get_authorization_url(
params["state"] = state

return RequestHelper.build_url_with_query_params(
base_url=self.client_configuration.base_url,
base_url=self._client_configuration.base_url,
path=AUTHORIZATION_PATH,
**params,
)
Expand Down Expand Up @@ -128,8 +128,10 @@ class SSO(SSOModule):

_http_client: SyncHTTPClient

def __init__(self, http_client: SyncHTTPClient):
self.client_configuration = http_client
def __init__(
self, http_client: SyncHTTPClient, client_configuration: ClientConfiguration
):
self._client_configuration = client_configuration
self._http_client = http_client

def get_profile(self, access_token: str) -> Profile:
Expand Down Expand Up @@ -257,8 +259,10 @@ class AsyncSSO(SSOModule):

_http_client: AsyncHTTPClient

def __init__(self, http_client: AsyncHTTPClient):
self.client_configuration = http_client
def __init__(
self, http_client: AsyncHTTPClient, client_configuration: ClientConfiguration
):
self._client_configuration = client_configuration
self._http_client = http_client

async def get_profile(self, access_token: str) -> Profile:
Expand Down
Loading

0 comments on commit 045fe75

Please sign in to comment.