Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small typing fixes #332

Merged
merged 11 commits into from
Aug 13, 2024
49 changes: 35 additions & 14 deletions workos/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import abstractmethod
import os
from typing import Generic, Optional, Type, TypeVar
from typing import Optional, Protocol

from workos.__about__ import __version__
from workos.fga import FGAModule
Expand All @@ -18,17 +18,22 @@
from workos.webhooks import WebhooksModule


HTTPClientType = TypeVar("HTTPClientType", bound=HTTPClient)
class ClientConfiguration(Protocol):
@property
def base_url(self) -> str: ...
@property
def client_id(self) -> str: ...
@property
def request_timeout(self) -> int: ...


class BaseClient(Generic[HTTPClientType]):
class BaseClient(ClientConfiguration):
"""Base client for accessing the WorkOS feature set."""

_api_key: str
_base_url: str
_client_id: str
_request_timeout: int
_http_client: HTTPClient

def __init__(
self,
Expand All @@ -37,7 +42,6 @@ def __init__(
client_id: Optional[str],
base_url: Optional[str] = None,
request_timeout: Optional[int] = None,
http_client_cls: Type[HTTPClientType],
) -> None:
api_key = api_key or os.getenv("WORKOS_API_KEY")
if api_key is None:
Expand All @@ -55,25 +59,18 @@ def __init__(

self._client_id = client_id

self._base_url = (
self.base_url = (
base_url
if base_url
else os.getenv("WORKOS_BASE_URL", "https://api.workos.com/")
)

self._request_timeout = (
request_timeout
if request_timeout
else int(os.getenv("WORKOS_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT))
)

self._http_client = http_client_cls(
api_key=self._api_key,
base_url=self._base_url,
client_id=self._client_id,
version=__version__,
timeout=self._request_timeout,
)

@property
@abstractmethod
def audit_logs(self) -> AuditLogsModule: ...
Expand Down Expand Up @@ -117,3 +114,27 @@ def user_management(self) -> UserManagementModule: ...
@property
@abstractmethod
def webhooks(self) -> WebhooksModule: ...

def _enforce_trailing_slash(self, url: str) -> str:
return url if url.endswith("/") else url + "/"

@property
def base_url(self) -> str:
return self._base_url
Comment on lines +114 to +116
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In base HTTP client we have a setter that enforces a trailing slash. Would be good to make sure these values are the same and perhaps enforce the trailing slash here instead:

@base_url.setter
def base_url(self, url: str) -> None:
"""Creates an accessible template for constructing the URL for an API request.
Args:
base_api_url (str): Base URL for api requests
"""
self._base_url = "{}{{}}".format(self._enforce_trailing_slash(url))


@base_url.setter
def base_url(self, url: str) -> None:
"""Creates an accessible template for constructing the URL for an API request.

Args:
base_api_url (str): Base URL for api requests
"""
self._base_url = "{}{{}}".format(self._enforce_trailing_slash(url))

@property
def client_id(self) -> str:
return self._client_id

@property
def request_timeout(self) -> int:
return self._request_timeout
20 changes: 15 additions & 5 deletions workos/async_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Optional

from workos.__about__ import __version__
from workos._base_client import BaseClient
from workos.audit_logs import AuditLogsModule
from workos.directory_sync import AsyncDirectorySync
Expand All @@ -15,7 +15,7 @@
from workos.webhooks import WebhooksModule


class AsyncClient(BaseClient[AsyncHTTPClient]):
class AsyncClient(BaseClient):
"""Client for a convenient way to access the WorkOS feature set."""

_http_client: AsyncHTTPClient
Expand All @@ -33,13 +33,21 @@ def __init__(
client_id=client_id,
base_url=base_url,
request_timeout=request_timeout,
http_client_cls=AsyncHTTPClient,
)
self._http_client = AsyncHTTPClient(
api_key=self._api_key,
base_url=self.base_url,
client_id=self._client_id,
version=__version__,
timeout=self.request_timeout,
)

@property
def sso(self) -> AsyncSSO:
if not getattr(self, "_sso", None):
self._sso = AsyncSSO(self._http_client)
self._sso = AsyncSSO(
http_client=self._http_client, client_configuration=self
)
return self._sso

@property
Expand Down Expand Up @@ -93,5 +101,7 @@ def mfa(self) -> MFAModule:
@property
def user_management(self) -> AsyncUserManagement:
if not getattr(self, "_user_management", None):
self._user_management = AsyncUserManagement(self._http_client)
self._user_management = AsyncUserManagement(
http_client=self._http_client, client_configuration=self
)
return self._user_management
18 changes: 13 additions & 5 deletions workos/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Optional

from workos.__about__ import __version__
from workos._base_client import BaseClient
from workos.audit_logs import AuditLogs
from workos.directory_sync import DirectorySync
Expand All @@ -15,7 +15,7 @@
from workos.utils.http_client import SyncHTTPClient


class SyncClient(BaseClient[SyncHTTPClient]):
class SyncClient(BaseClient):
"""Client for a convenient way to access the WorkOS feature set."""

_http_client: SyncHTTPClient
Expand All @@ -33,13 +33,19 @@ def __init__(
client_id=client_id,
base_url=base_url,
request_timeout=request_timeout,
http_client_cls=SyncHTTPClient,
)
self._http_client = SyncHTTPClient(
api_key=self._api_key,
base_url=self.base_url,
client_id=self._client_id,
version=__version__,
timeout=self.request_timeout,
)

@property
def sso(self) -> SSO:
if not getattr(self, "_sso", None):
self._sso = SSO(self._http_client)
self._sso = SSO(http_client=self._http_client, client_configuration=self)
return self._sso

@property
Expand Down Expand Up @@ -99,5 +105,7 @@ def mfa(self) -> Mfa:
@property
def user_management(self) -> UserManagement:
if not getattr(self, "_user_management", None):
self._user_management = UserManagement(self._http_client)
self._user_management = UserManagement(
http_client=self._http_client, client_configuration=self
)
return self._user_management
8 changes: 4 additions & 4 deletions workos/directory_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ def list_groups(
order: PaginationOrder = "desc",
) -> SyncOrAsync[DirectoryGroupsListResource]: ...

def get_user(self, user: str) -> SyncOrAsync[DirectoryUserWithGroups]: ...
def get_user(self, user_id: str) -> SyncOrAsync[DirectoryUserWithGroups]: ...

def get_group(self, group: str) -> SyncOrAsync[DirectoryGroup]: ...
def get_group(self, group_id: str) -> SyncOrAsync[DirectoryGroup]: ...

def get_directory(self, directory: str) -> SyncOrAsync[Directory]: ...
def get_directory(self, directory_id: str) -> SyncOrAsync[Directory]: ...

def list_directories(
self,
Expand All @@ -77,7 +77,7 @@ def list_directories(
order: PaginationOrder = "desc",
) -> SyncOrAsync[DirectoriesListResource]: ...

def delete_directory(self, directory: str) -> SyncOrAsync[None]: ...
def delete_directory(self, directory_id: str) -> SyncOrAsync[None]: ...


class DirectorySync(DirectorySyncModule):
Expand Down
28 changes: 20 additions & 8 deletions workos/sso.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, Protocol
from workos._base_client import ClientConfiguration
from workos.types.sso.connection import ConnectionType
from workos.types.sso.sso_provider_type import SsoProviderType
from workos.typing.sync_or_async import SyncOrAsync
Expand Down Expand Up @@ -40,7 +41,7 @@ class ConnectionsListFilters(ListArgs, total=False):


class SSOModule(Protocol):
_http_client: HTTPClient
_client_configuration: ClientConfiguration

def get_authorization_url(
self,
Expand Down Expand Up @@ -70,7 +71,7 @@ def get_authorization_url(
str: URL to redirect a User to to begin the OAuth workflow with WorkOS
"""
params: QueryParameters = {
"client_id": self._http_client.client_id,
"client_id": self._client_configuration.client_id,
"redirect_uri": redirect_uri,
"response_type": RESPONSE_TYPE_CODE,
}
Expand All @@ -94,10 +95,12 @@ def get_authorization_url(
params["state"] = state

return RequestHelper.build_url_with_query_params(
base_url=self._http_client.base_url, path=AUTHORIZATION_PATH, **params
base_url=self._client_configuration.base_url,
path=AUTHORIZATION_PATH,
**params,
)

def get_profile(self, accessToken: str) -> SyncOrAsync[Profile]: ...
def get_profile(self, access_token: str) -> SyncOrAsync[Profile]: ...

def get_profile_and_token(self, code: str) -> SyncOrAsync[ProfileAndToken]: ...

Expand All @@ -117,15 +120,18 @@ def list_connections(
order: PaginationOrder = "desc",
) -> SyncOrAsync[ConnectionsListResource]: ...

def delete_connection(self, connection: str) -> SyncOrAsync[None]: ...
def delete_connection(self, connection_id: str) -> SyncOrAsync[None]: ...


class SSO(SSOModule):
"""Offers methods to assist in authenticating through the WorkOS SSO service."""

_http_client: SyncHTTPClient

def __init__(self, http_client: SyncHTTPClient):
def __init__(
self, http_client: SyncHTTPClient, client_configuration: ClientConfiguration
):
self._client_configuration = client_configuration
self._http_client = http_client

def get_profile(self, access_token: str) -> Profile:
Expand All @@ -139,7 +145,10 @@ def get_profile(self, access_token: str) -> Profile:
Profile
"""
response = self._http_client.request(
PROFILE_PATH, method=REQUEST_METHOD_GET, token=access_token
PROFILE_PATH,
method=REQUEST_METHOD_GET,
headers={**self._http_client.auth_header_from_token(access_token)},
exclude_default_auth_headers=True,
)

return Profile.model_validate(response)
Expand Down Expand Up @@ -250,7 +259,10 @@ class AsyncSSO(SSOModule):

_http_client: AsyncHTTPClient

def __init__(self, http_client: AsyncHTTPClient):
def __init__(
self, http_client: AsyncHTTPClient, client_configuration: ClientConfiguration
):
self._client_configuration = client_configuration
self._http_client = http_client

async def get_profile(self, access_token: str) -> Profile:
Expand Down
Loading
Loading