diff --git a/tests/conftest.py b/tests/conftest.py index fddf8fff..bf96ee24 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -128,7 +128,7 @@ def mock(*args, **kwargs): def mock_http_client_with_response(monkeypatch): def inner( http_client: Union[SyncHTTPClient, AsyncHTTPClient], - response_dict: dict, + response_dict: Optional[dict] = None, status_code: int = 200, headers: Optional[Mapping[str, str]] = None, ): @@ -145,6 +145,39 @@ def inner( return inner +@pytest.fixture +def capture_and_mock_http_client_request(monkeypatch): + def inner( + http_client: Union[SyncHTTPClient, AsyncHTTPClient], + response_dict: dict, + status_code: int = 200, + headers: Optional[Mapping[str, str]] = None, + ): + request_args = [] + request_kwargs = {} + + def capture_and_mock(*args, **kwargs): + request_args.extend(args) + request_kwargs.update(kwargs) + + return httpx.Response( + status_code=status_code, + headers=headers, + json=response_dict, + ) + + mock_class = ( + AsyncMock if isinstance(http_client, AsyncHTTPClient) else MagicMock + ) + mock = mock_class(side_effect=capture_and_mock) + + monkeypatch.setattr(http_client._client, "request", mock) + + return (request_args, request_kwargs) + + return inner + + @pytest.fixture def mock_pagination_request_for_http_client(monkeypatch): # Mocking pagination correctly requires us to index into a list of data diff --git a/tests/test_client.py b/tests/test_client.py index d7eabea9..177b86a2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -171,3 +171,38 @@ def test_initialize_events_missing_api_key(self): message = str(ex) assert "api_key" in message + + def test_initialize_sso(self, set_api_key_and_client_id): + assert bool(async_client.sso) + + def test_initialize_sso_missing_api_key(self, set_client_id): + with pytest.raises(ConfigurationException) as ex: + async_client.sso + + message = str(ex) + + assert "api_key" in message + assert "client_id" not in message + + def test_initialize_sso_missing_client_id(self, set_api_key): + with pytest.raises(ConfigurationException) as ex: + async_client.sso + + message = str(ex) + + assert "client_id" in message + assert "api_key" not in message + + def test_initialize_sso_missing_api_key_and_client_id(self): + with pytest.raises(ConfigurationException) as ex: + async_client.sso + + message = str(ex) + + assert all( + setting in message + for setting in ( + "api_key", + "client_id", + ) + ) diff --git a/tests/test_directory_sync.py b/tests/test_directory_sync.py index dcdc6cce..5b36aa94 100644 --- a/tests/test_directory_sync.py +++ b/tests/test_directory_sync.py @@ -213,7 +213,6 @@ def test_delete_directory(self, mock_http_client_with_response): mock_http_client_with_response( http_client=self.http_client, status_code=202, - response_dict=None, headers={"content-type": "text/plain; charset=utf-8"}, ) @@ -450,7 +449,6 @@ async def test_delete_directory(self, mock_http_client_with_response): mock_http_client_with_response( http_client=self.http_client, status_code=202, - response_dict=None, headers={"content-type": "text/plain; charset=utf-8"}, ) diff --git a/tests/test_organizations.py b/tests/test_organizations.py index e559cb33..3eb7bb81 100644 --- a/tests/test_organizations.py +++ b/tests/test_organizations.py @@ -1,11 +1,8 @@ import datetime -from typing import Dict, List, Union, cast import pytest -import requests -from tests.conftest import MockResponse -from tests.utils.list_resource import list_data_to_dicts, list_response_of +from tests.utils.list_resource import list_data_to_dicts from workos.organizations import Organizations from tests.utils.fixtures.mock_organization import MockOrganization diff --git a/tests/test_sso.py b/tests/test_sso.py index 70649fd7..d5bac063 100644 --- a/tests/test_sso.py +++ b/tests/test_sso.py @@ -1,30 +1,22 @@ import json + from six.moves.urllib.parse import parse_qsl, urlparse import pytest + +from tests.utils.list_resource import list_data_to_dicts, list_response_of import workos -from workos.sso import SSO -from workos.utils.connection_types import ConnectionType -from workos.utils.sso_provider_types import SsoProviderType +from workos.sso import SSO, AsyncSSO +from workos.resources.sso import SsoProviderType +from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient from workos.utils.request import RESPONSE_TYPE_CODE from tests.utils.fixtures.mock_connection import MockConnection -class TestSSO(object): - @pytest.fixture - def setup_with_client_id(self, set_api_key_and_client_id): - self.sso = SSO() - self.provider = SsoProviderType.GoogleOAuth - self.customer_domain = "workos.com" - self.login_hint = "foo@workos.com" - self.redirect_uri = "https://localhost/auth/callback" - self.state = json.dumps({"things": "with_stuff"}) - self.connection = "connection_123" - self.organization = "organization_123" - self.setup_completed = True - +class SSOFixtures: @pytest.fixture def mock_profile(self): return { + "object": "profile", "id": "prof_01DWAS7ZQWM70PV93BFV1V78QV", "email": "demo@workos-okta.com", "first_name": "WorkOS", @@ -45,6 +37,7 @@ def mock_profile(self): @pytest.fixture def mock_magic_link_profile(self): return { + "object": "profile", "id": "prof_01DWAS7ZQWM70PV93BFV1V78QV", "email": "demo@workos-magic-link.com", "organization_id": None, @@ -63,221 +56,64 @@ def mock_connection(self): @pytest.fixture def mock_connections(self): - connection_list = [MockConnection(id=str(i)).to_dict() for i in range(5000)] - - return { - "data": connection_list, - "list_metadata": {"before": None, "after": None}, - "metadata": { - "params": { - "domains": None, - "limit": 4, - "before": None, - "after": None, - "order": None, - "default_limit": True, - }, - "method": SSO.list_connections, - }, - } - - @pytest.fixture - def mock_connections_with_limit(self): - connection_list = [MockConnection(id=str(i)).to_dict() for i in range(4)] - - return { - "data": connection_list, - "list_metadata": {"before": None, "after": None}, - "metadata": { - "params": { - "connection_type": None, - "domain": None, - "organization_id": None, - "limit": 4, - "before": None, - "after": None, - "order": None, - }, - "method": SSO.list_connections, - }, - } - - @pytest.fixture - def mock_connections_with_limit_v2(self, set_api_key_and_client_id): - connection_list = [MockConnection(id=str(i)).to_dict() for i in range(4)] - - dict_response = { - "data": connection_list, - "list_metadata": {"before": None, "after": None}, - "metadata": { - "params": { - "connection_type": None, - "domain": None, - "organization_id": None, - "limit": 4, - "before": None, - "after": None, - "order": None, - }, - "method": SSO.list_connections_v2, - }, - } - return SSO.construct_from_response(dict_response) - - @pytest.fixture - def mock_connections_with_default_limit(self): connection_list = [MockConnection(id=str(i)).to_dict() for i in range(10)] - return { - "data": connection_list, - "list_metadata": {"before": None, "after": "conn_xxx"}, - "metadata": { - "params": { - "connection_type": None, - "domain": None, - "organization_id": None, - "limit": 4, - "before": None, - "after": None, - "order": None, - "default_limit": True, - }, - "method": SSO.list_connections, - }, - } + return list_response_of(data=connection_list) @pytest.fixture - def mock_connections_with_default_limit_v2(self, setup_with_client_id): - connection_list = [MockConnection(id=str(i)).to_dict() for i in range(10)] + def mock_connections_multiple_data_pages(self): + return [MockConnection(id=str(i)).to_dict() for i in range(40)] - dict_response = { - "data": connection_list, - "list_metadata": {"before": None, "after": "conn_xxx"}, - "metadata": { - "params": { - "connection_type": None, - "domain": None, - "organization_id": None, - "limit": 4, - "before": None, - "after": None, - "order": None, - "default_limit": True, - }, - "method": SSO.list_connections_v2, - }, - } - return self.sso.construct_from_response(dict_response) - @pytest.fixture - def mock_connections_pagination_response(self): - connection_list = [MockConnection(id=str(i)).to_dict() for i in range(4990)] +class TestSSOBase(SSOFixtures): + provider: SsoProviderType - return { - "data": connection_list, - "list_metadata": {"before": None, "after": None}, - "metadata": { - "params": { - "connection_type": None, - "domain": None, - "organization_id": None, - "limit": None, - "before": None, - "after": None, - "order": None, - "default_limit": True, - }, - "method": SSO.list_connections, - }, - } + @pytest.fixture(autouse=True) + def setup(self, set_api_key_and_client_id): + self.http_client = SyncHTTPClient( + base_url="https://api.workos.test", version="test" + ) + self.sso = SSO(http_client=self.http_client) + self.provider = "GoogleOAuth" + self.customer_domain = "workos.com" + self.login_hint = "foo@workos.com" + self.redirect_uri = "https://localhost/auth/callback" + self.authorization_state = json.dumps({"things": "with_stuff"}) + self.connection_id = "connection_123" + self.organization_id = "organization_123" + self.setup_completed = True - def test_authorization_url_throws_value_error_with_missing_connection_domain_and_provider( - self, setup_with_client_id + def test_authorization_url_throws_value_error_with_missing_connection_organization_and_provider( + self, ): with pytest.raises(ValueError, match=r"Incomplete arguments.*"): self.sso.get_authorization_url( - redirect_uri=self.redirect_uri, state=self.state - ) - - @pytest.mark.parametrize( - "invalid_provider", - [ - 123, - SsoProviderType, - True, - False, - {"provider": "GoogleOAuth"}, - ["GoogleOAuth"], - ], - ) - def test_authorization_url_throws_value_error_with_incorrect_provider_type( - self, setup_with_client_id, invalid_provider - ): - with pytest.raises( - ValueError, match="'provider' must be of type SsoProviderType" - ): - self.sso.get_authorization_url( - provider=invalid_provider, - redirect_uri=self.redirect_uri, - state=self.state, - ) - - def test_authorization_url_throws_value_error_without_redirect_uri( - self, setup_with_client_id - ): - with pytest.raises( - ValueError, match="Incomplete arguments. Need to specify a 'redirect_uri'." - ): - self.sso.get_authorization_url( - connection=self.connection, - login_hint=self.login_hint, - state=self.state, + redirect_uri=self.redirect_uri, state=self.authorization_state ) - def test_authorization_url_has_expected_query_params_with_provider( - self, setup_with_client_id - ): + def test_authorization_url_has_expected_query_params_with_provider(self): authorization_url = self.sso.get_authorization_url( - provider=self.provider, redirect_uri=self.redirect_uri, state=self.state - ) - - parsed_url = urlparse(authorization_url) - - assert dict(parse_qsl(parsed_url.query)) == { - "provider": str(self.provider.value), - "client_id": workos.client_id, - "redirect_uri": self.redirect_uri, - "response_type": RESPONSE_TYPE_CODE, - "state": self.state, - } - - def test_authorization_url_has_expected_query_params_with_domain( - self, setup_with_client_id - ): - authorization_url = self.sso.get_authorization_url( - domain=self.customer_domain, + provider=self.provider, redirect_uri=self.redirect_uri, - state=self.state, + state=self.authorization_state, ) parsed_url = urlparse(authorization_url) assert dict(parse_qsl(parsed_url.query)) == { - "domain": self.customer_domain, + "provider": self.provider, "client_id": workos.client_id, "redirect_uri": self.redirect_uri, "response_type": RESPONSE_TYPE_CODE, - "state": self.state, + "state": self.authorization_state, } - def test_authorization_url_has_expected_query_params_with_domain_hint( - self, setup_with_client_id - ): + def test_authorization_url_has_expected_query_params_with_domain_hint(self): authorization_url = self.sso.get_authorization_url( - connection=self.connection, + connection_id=self.connection_id, domain_hint=self.customer_domain, redirect_uri=self.redirect_uri, - state=self.state, + state=self.authorization_state, ) parsed_url = urlparse(authorization_url) @@ -286,19 +122,17 @@ def test_authorization_url_has_expected_query_params_with_domain_hint( "domain_hint": self.customer_domain, "client_id": workos.client_id, "redirect_uri": self.redirect_uri, - "connection": self.connection, + "connection": self.connection_id, "response_type": RESPONSE_TYPE_CODE, - "state": self.state, + "state": self.authorization_state, } - def test_authorization_url_has_expected_query_params_with_login_hint( - self, setup_with_client_id - ): + def test_authorization_url_has_expected_query_params_with_login_hint(self): authorization_url = self.sso.get_authorization_url( - connection=self.connection, + connection_id=self.connection_id, login_hint=self.login_hint, redirect_uri=self.redirect_uri, - state=self.state, + state=self.authorization_state, ) parsed_url = urlparse(authorization_url) @@ -307,93 +141,108 @@ def test_authorization_url_has_expected_query_params_with_login_hint( "login_hint": self.login_hint, "client_id": workos.client_id, "redirect_uri": self.redirect_uri, - "connection": self.connection, + "connection": self.connection_id, "response_type": RESPONSE_TYPE_CODE, - "state": self.state, + "state": self.authorization_state, } - def test_authorization_url_has_expected_query_params_with_connection( - self, setup_with_client_id - ): + def test_authorization_url_has_expected_query_params_with_connection(self): authorization_url = self.sso.get_authorization_url( - connection=self.connection, + connection_id=self.connection_id, redirect_uri=self.redirect_uri, - state=self.state, + state=self.authorization_state, ) parsed_url = urlparse(authorization_url) assert dict(parse_qsl(parsed_url.query)) == { - "connection": self.connection, + "connection": self.connection_id, "client_id": workos.client_id, "redirect_uri": self.redirect_uri, "response_type": RESPONSE_TYPE_CODE, - "state": self.state, + "state": self.authorization_state, } def test_authorization_url_with_string_provider_has_expected_query_params_with_organization( - self, setup_with_client_id + self, ): authorization_url = self.sso.get_authorization_url( provider=self.provider, - organization=self.organization, + organization_id=self.organization_id, redirect_uri=self.redirect_uri, - state=self.state, + state=self.authorization_state, ) parsed_url = urlparse(authorization_url) assert dict(parse_qsl(parsed_url.query)) == { - "organization": self.organization, - "provider": self.provider.value, + "organization": self.organization_id, + "provider": self.provider, "client_id": workos.client_id, "redirect_uri": self.redirect_uri, "response_type": RESPONSE_TYPE_CODE, - "state": self.state, + "state": self.authorization_state, } - def test_authorization_url_has_expected_query_params_with_organization( - self, setup_with_client_id - ): + def test_authorization_url_has_expected_query_params_with_organization(self): authorization_url = self.sso.get_authorization_url( - organization=self.organization, + organization_id=self.organization_id, redirect_uri=self.redirect_uri, - state=self.state, + state=self.authorization_state, ) parsed_url = urlparse(authorization_url) assert dict(parse_qsl(parsed_url.query)) == { - "organization": self.organization, + "organization": self.organization_id, "client_id": workos.client_id, "redirect_uri": self.redirect_uri, "response_type": RESPONSE_TYPE_CODE, - "state": self.state, + "state": self.authorization_state, } - def test_authorization_url_has_expected_query_params_with_domain_and_provider( - self, setup_with_client_id + def test_authorization_url_has_expected_query_params_with_organization_and_provider( + self, ): authorization_url = self.sso.get_authorization_url( - domain=self.customer_domain, + organization_id=self.organization_id, provider=self.provider, redirect_uri=self.redirect_uri, - state=self.state, + state=self.authorization_state, ) parsed_url = urlparse(authorization_url) assert dict(parse_qsl(parsed_url.query)) == { - "domain": self.customer_domain, - "provider": str(self.provider.value), + "organization": self.organization_id, + "provider": self.provider, "client_id": workos.client_id, "redirect_uri": self.redirect_uri, "response_type": RESPONSE_TYPE_CODE, - "state": self.state, + "state": self.authorization_state, } - def test_get_profile_and_token_returns_expected_workosprofile_object( - self, setup_with_client_id, mock_profile, mock_request_method + +class TestSSO(SSOFixtures): + provider: SsoProviderType + + @pytest.fixture(autouse=True) + def setup(self, set_api_key_and_client_id): + self.http_client = SyncHTTPClient( + base_url="https://api.workos.test", version="test" + ) + self.sso = SSO(http_client=self.http_client) + self.provider = "GoogleOAuth" + self.customer_domain = "workos.com" + self.login_hint = "foo@workos.com" + self.redirect_uri = "https://localhost/auth/callback" + self.state = json.dumps({"things": "with_stuff"}) + self.connection_id = "connection_123" + self.organization_id = "organization_123" + self.setup_completed = True + + def test_get_profile_and_token_returns_expected_profile_object( + self, mock_profile, mock_http_client_with_response ): response_dict = { "profile": { @@ -417,190 +266,237 @@ def test_get_profile_and_token_returns_expected_workosprofile_object( "access_token": "01DY34ACQTM3B1CSX1YSZ8Z00D", } - mock_request_method("post", response_dict, 200) + mock_http_client_with_response(self.http_client, response_dict, 200) - profile_and_token = self.sso.get_profile_and_token(123) + profile_and_token = self.sso.get_profile_and_token("123") assert profile_and_token.access_token == "01DY34ACQTM3B1CSX1YSZ8Z00D" - assert profile_and_token.profile.to_dict() == mock_profile + assert profile_and_token.profile.dict() == mock_profile - def test_get_profile_and_token_without_first_name_or_last_name_returns_expected_workosprofile_object( - self, setup_with_client_id, mock_magic_link_profile, mock_request_method + def test_get_profile_and_token_without_first_name_or_last_name_returns_expected_profile_object( + self, mock_magic_link_profile, mock_http_client_with_response ): response_dict = { - "profile": { - "object": "profile", - "id": mock_magic_link_profile["id"], - "email": mock_magic_link_profile["email"], - "organization_id": mock_magic_link_profile["organization_id"], - "connection_id": mock_magic_link_profile["connection_id"], - "connection_type": mock_magic_link_profile["connection_type"], - "idp_id": "", - "raw_attributes": {}, - }, + "profile": mock_magic_link_profile, "access_token": "01DY34ACQTM3B1CSX1YSZ8Z00D", } - mock_request_method("post", response_dict, 200) + mock_http_client_with_response(self.http_client, response_dict, 200) - profile_and_token = self.sso.get_profile_and_token(123) + profile_and_token = self.sso.get_profile_and_token("123") assert profile_and_token.access_token == "01DY34ACQTM3B1CSX1YSZ8Z00D" - assert profile_and_token.profile.to_dict() == mock_magic_link_profile + assert profile_and_token.profile.dict() == mock_magic_link_profile - def test_get_profile(self, setup_with_client_id, mock_profile, mock_request_method): - mock_request_method("get", mock_profile, 200) + def test_get_profile(self, mock_profile, mock_http_client_with_response): + mock_http_client_with_response(self.http_client, mock_profile, 200) - profile = self.sso.get_profile(123) + profile = self.sso.get_profile("123") - assert profile.to_dict() == mock_profile + assert profile.dict() == mock_profile - def test_get_connection( - self, setup_with_client_id, mock_connection, mock_request_method - ): - mock_request_method("get", mock_connection, 200) + def test_get_connection(self, mock_connection, mock_http_client_with_response): + mock_http_client_with_response(self.http_client, mock_connection, 200) - connection = self.sso.get_connection(connection="connection_id") + connection = self.sso.get_connection(connection_id="connection_id") - assert connection == mock_connection + assert connection.dict() == mock_connection - def test_list_connections( - self, setup_with_client_id, mock_connections, mock_request_method - ): - mock_request_method("get", mock_connections, 200) + def test_list_connections(self, mock_connections, mock_http_client_with_response): + mock_http_client_with_response(self.http_client, mock_connections, 200) - connections_response = self.sso.list_connections() + connections = self.sso.list_connections() - assert connections_response["data"] == mock_connections["data"] + assert list_data_to_dicts(connections.data) == mock_connections["data"] - def test_list_connections_with_connection_type_as_invalid_string( - self, setup_with_client_id, mock_connections, mock_request_method + def test_list_connections_with_connection_type( + self, mock_connections, capture_and_mock_http_client_request ): - mock_request_method("get", mock_connections, 200) + _, request_kwargs = capture_and_mock_http_client_request( + http_client=self.http_client, + response_dict=mock_connections, + status_code=200, + ) - with pytest.raises( - ValueError, match="'connection_type' must be a member of ConnectionType" - ): - self.sso.list_connections(connection_type="UnknownSAML") + self.sso.list_connections(connection_type="GenericSAML") - def test_list_connections_with_connection_type_as_string( - self, setup_with_client_id, mock_connections, capture_and_mock_request - ): - request_args, request_kwargs = capture_and_mock_request( - "get", mock_connections, 200 + assert request_kwargs["params"] == { + "connection_type": "GenericSAML", + "limit": 10, + "order": "desc", + } + + def test_delete_connection(self, mock_http_client_with_response): + mock_http_client_with_response( + self.http_client, + status_code=204, + headers={"content-type": "text/plain; charset=utf-8"}, ) - connections_response = self.sso.list_connections(connection_type="GenericSAML") + response = self.sso.delete_connection(connection_id="connection_id") - request_params = request_kwargs["params"] - assert request_params["connection_type"] == "GenericSAML" + assert response is None - def test_list_connections_with_connection_type_as_enum( - self, setup_with_client_id, mock_connections, capture_and_mock_request + def test_list_connections_auto_pagination( + self, + mock_connections_multiple_data_pages, + mock_pagination_request_for_http_client, ): - request_args, request_kwargs = capture_and_mock_request( - "get", mock_connections, 200 + mock_pagination_request_for_http_client( + http_client=self.http_client, + data_list=mock_connections_multiple_data_pages, + status_code=200, ) - connections_response = self.sso.list_connections( - connection_type=ConnectionType.OktaSAML - ) + connections = self.sso.list_connections() + all_connections = [] - request_params = request_kwargs["params"] - assert request_params["connection_type"] == "OktaSAML" + for connection in connections.auto_paging_iter(): + all_connections.append(connection) - def test_delete_connection(self, setup_with_client_id, mock_raw_request_method): - mock_raw_request_method( - "delete", - "No Content", - 204, - headers={"content-type": "text/plain; charset=utf-8"}, + assert len(list(all_connections)) == len(mock_connections_multiple_data_pages) + assert ( + list_data_to_dicts(all_connections) + ) == mock_connections_multiple_data_pages + + +@pytest.mark.asyncio +class TestAsyncSSO(SSOFixtures): + provider: SsoProviderType + + @pytest.fixture(autouse=True) + def setup(self, set_api_key_and_client_id): + self.http_client = AsyncHTTPClient( + base_url="https://api.workos.test", version="test" ) + self.sso = AsyncSSO(http_client=self.http_client) + self.provider = "GoogleOAuth" + self.customer_domain = "workos.com" + self.login_hint = "foo@workos.com" + self.redirect_uri = "https://localhost/auth/callback" + self.state = json.dumps({"things": "with_stuff"}) + self.connection_id = "connection_123" + self.organization_id = "organization_123" + self.setup_completed = True + + async def test_get_profile_and_token_returns_expected_profile_object( + self, mock_profile, mock_http_client_with_response + ): + response_dict = { + "profile": { + "object": "profile", + "id": mock_profile["id"], + "email": mock_profile["email"], + "first_name": mock_profile["first_name"], + "groups": mock_profile["groups"], + "organization_id": mock_profile["organization_id"], + "connection_id": mock_profile["connection_id"], + "connection_type": mock_profile["connection_type"], + "last_name": mock_profile["last_name"], + "idp_id": mock_profile["idp_id"], + "raw_attributes": { + "email": mock_profile["raw_attributes"]["email"], + "first_name": mock_profile["raw_attributes"]["first_name"], + "last_name": mock_profile["raw_attributes"]["last_name"], + "groups": mock_profile["raw_attributes"]["groups"], + }, + }, + "access_token": "01DY34ACQTM3B1CSX1YSZ8Z00D", + } - response = self.sso.delete_connection(connection="connection_id") + mock_http_client_with_response(self.http_client, response_dict, 200) - assert response is None + profile_and_token = await self.sso.get_profile_and_token("123") - def test_list_connections_auto_pagination( - self, - mock_connections_with_default_limit, - mock_connections_pagination_response, - mock_connections, - mock_request_method, - setup_with_client_id, + assert profile_and_token.access_token == "01DY34ACQTM3B1CSX1YSZ8Z00D" + assert profile_and_token.profile.dict() == mock_profile + + async def test_get_profile_and_token_without_first_name_or_last_name_returns_expected_profile_object( + self, mock_magic_link_profile, mock_http_client_with_response ): - mock_request_method("get", mock_connections_pagination_response, 200) - connections = mock_connections_with_default_limit + response_dict = { + "profile": mock_magic_link_profile, + "access_token": "01DY34ACQTM3B1CSX1YSZ8Z00D", + } - all_connections = SSO.construct_from_response(connections).auto_paging_iter() + mock_http_client_with_response(self.http_client, response_dict, 200) - assert len(*list(all_connections)) == len(mock_connections["data"]) + profile_and_token = await self.sso.get_profile_and_token("123") - def test_list_connections_auto_pagination_v2( - self, - mock_connections_with_default_limit_v2, - mock_connections_pagination_response, - mock_connections, - mock_request_method, - setup_with_client_id, + assert profile_and_token.access_token == "01DY34ACQTM3B1CSX1YSZ8Z00D" + assert profile_and_token.profile.dict() == mock_magic_link_profile + + async def test_get_profile(self, mock_profile, mock_http_client_with_response): + mock_http_client_with_response(self.http_client, mock_profile, 200) + + profile = await self.sso.get_profile("123") + + assert profile.dict() == mock_profile + + async def test_get_connection( + self, mock_connection, mock_http_client_with_response ): - connections = mock_connections_with_default_limit_v2 + mock_http_client_with_response(self.http_client, mock_connection, 200) - mock_request_method("get", mock_connections_pagination_response, 200) - all_connections = connections.auto_paging_iter() + connection = await self.sso.get_connection(connection_id="connection_id") - number_of_connections = len(*list(all_connections)) - assert number_of_connections == len(mock_connections["data"]) + assert connection.dict() == mock_connection - def test_list_connections_honors_limit( - self, - mock_connections_with_limit, - mock_connections_pagination_response, - mock_request_method, - setup_with_client_id, + async def test_list_connections( + self, mock_connections, mock_http_client_with_response ): - connections = mock_connections_with_limit - mock_request_method("get", mock_connections_pagination_response, 200) - all_connections = SSO.construct_from_response(connections).auto_paging_iter() + mock_http_client_with_response(self.http_client, mock_connections, 200) - assert len(*list(all_connections)) == len(mock_connections_with_limit["data"]) + connections = await self.sso.list_connections() - def test_list_connections_honors_limit_v2( - self, - mock_connections_with_limit_v2, - mock_connections_pagination_response, - mock_request_method, - setup_with_client_id, + assert list_data_to_dicts(connections.data) == mock_connections["data"] + + async def test_list_connections_with_connection_type( + self, mock_connections, capture_and_mock_http_client_request ): - connections = mock_connections_with_limit_v2 - mock_request_method("get", mock_connections_pagination_response, 200) - all_connections = connections.auto_paging_iter() - dict_mock_connections_with_limit = mock_connections_with_limit_v2.to_dict() + _, request_kwargs = capture_and_mock_http_client_request( + http_client=self.http_client, + response_dict=mock_connections, + status_code=200, + ) - assert len(*list(all_connections)) == len( - dict_mock_connections_with_limit["data"] + await self.sso.list_connections(connection_type="GenericSAML") + + assert request_kwargs["params"] == { + "connection_type": "GenericSAML", + "limit": 10, + "order": "desc", + } + + async def test_delete_connection(self, mock_http_client_with_response): + mock_http_client_with_response( + self.http_client, + status_code=204, + headers={"content-type": "text/plain; charset=utf-8"}, ) - def test_list_connections_returns_metadata( - self, - mock_connections, - mock_request_method, - setup_with_client_id, - ): - mock_request_method("get", mock_connections, 200) - connections = self.sso.list_connections(domain="planet-express.com") + response = await self.sso.delete_connection(connection_id="connection_id") - assert connections["metadata"]["params"]["domain"] == "planet-express.com" + assert response is None - def test_list_connections_returns_metadata_v2( + async def test_list_connections_auto_pagination( self, - mock_connections, - mock_request_method, - setup_with_client_id, + mock_connections_multiple_data_pages, + mock_pagination_request_for_http_client, ): - mock_request_method("get", mock_connections, 200) + mock_pagination_request_for_http_client( + http_client=self.http_client, + data_list=mock_connections_multiple_data_pages, + status_code=200, + ) + + connections = await self.sso.list_connections() + all_connections = [] - connections = self.sso.list_connections_v2(domain="planet-express.com") - dict_connections = connections.to_dict() + async for connection in connections.auto_paging_iter(): + all_connections.append(connection) - assert dict_connections["metadata"]["params"]["domain"] == "planet-express.com" + assert len(list(all_connections)) == len(mock_connections_multiple_data_pages) + assert ( + list_data_to_dicts(all_connections) + ) == mock_connections_multiple_data_pages diff --git a/tests/utils/fixtures/mock_connection.py b/tests/utils/fixtures/mock_connection.py index d15101bb..1bf7352a 100644 --- a/tests/utils/fixtures/mock_connection.py +++ b/tests/utils/fixtures/mock_connection.py @@ -4,16 +4,21 @@ class MockConnection(WorkOSBaseResource): def __init__(self, id): - self.object = "organization" + self.object = "connection" self.id = id self.organization_id = "org_id_" + id - self.connection_type = "Okta" + self.connection_type = "OktaSAML" self.name = "Foo Corporation" - self.state = None - self.created_at = datetime.datetime.now() - self.updated_at = datetime.datetime.now() - self.status = None - self.domains = ["domain1.com"] + self.state = "active" + self.created_at = datetime.datetime.now().isoformat() + self.updated_at = datetime.datetime.now().isoformat() + self.domains = [ + { + "id": "connection_domain_abc123", + "object": "connection_domain", + "domain": "domain1.com", + } + ] OBJECT_FIELDS = [ "object", @@ -24,6 +29,5 @@ def __init__(self, id): "state", "created_at", "updated_at", - "status", "domains", ] diff --git a/workos/async_client.py b/workos/async_client.py index c238c387..1b1d1cd7 100644 --- a/workos/async_client.py +++ b/workos/async_client.py @@ -1,6 +1,7 @@ from workos._base_client import BaseClient from workos.directory_sync import AsyncDirectorySync from workos.events import AsyncEvents +from workos.sso import AsyncSSO from workos.utils.http_client import AsyncHTTPClient @@ -16,7 +17,9 @@ def __init__(self, base_url: str, version: str, timeout: int): @property def sso(self): - raise NotImplementedError("SSO APIs are not yet supported in the async client.") + if not getattr(self, "_sso", None): + self._sso = AsyncSSO(self._http_client) + return self._sso @property def audit_logs(self): diff --git a/workos/client.py b/workos/client.py index 82803ed7..0fa77d5c 100644 --- a/workos/client.py +++ b/workos/client.py @@ -25,7 +25,7 @@ def __init__(self, base_url: str, version: str, timeout: int): @property def sso(self): if not getattr(self, "_sso", None): - self._sso = SSO() + self._sso = SSO(self._http_client) return self._sso @property diff --git a/workos/directory_sync.py b/workos/directory_sync.py index 678cf422..4ff4fc36 100644 --- a/workos/directory_sync.py +++ b/workos/directory_sync.py @@ -3,7 +3,11 @@ from workos.typing.sync_or_async import SyncOrAsync from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient from workos.utils.pagination_order import PaginationOrder -from workos.utils.request import REQUEST_METHOD_DELETE, REQUEST_METHOD_GET +from workos.utils.request import ( + DEFAULT_LIST_RESPONSE_LIMIT, + REQUEST_METHOD_DELETE, + REQUEST_METHOD_GET, +) from workos.utils.validation import DIRECTORY_SYNC_MODULE, validate_settings from workos.resources.directory_sync import ( DirectoryGroup, @@ -19,9 +23,6 @@ ) -RESPONSE_LIMIT = 10 - - class DirectoryListFilters(ListArgs, total=False): search: Optional[str] organization_id: Optional[str] @@ -46,7 +47,7 @@ def list_users( self, directory: Optional[str] = None, group: Optional[str] = None, - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", @@ -56,7 +57,7 @@ def list_groups( self, directory: Optional[str] = None, user: Optional[str] = None, - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", @@ -72,7 +73,7 @@ def list_directories( self, domain: Optional[str] = None, search: Optional[str] = None, - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, organization: Optional[str] = None, @@ -95,7 +96,7 @@ def list_users( self, directory: Optional[str] = None, group: Optional[str] = None, - limit: Optional[int] = RESPONSE_LIMIT, + limit: Optional[int] = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", @@ -117,7 +118,7 @@ def list_users( """ list_params: DirectoryUserListFilters = { - "limit": limit if limit is not None else RESPONSE_LIMIT, + "limit": limit if limit is not None else DEFAULT_LIST_RESPONSE_LIMIT, "before": before, "after": after, "order": order, @@ -145,7 +146,7 @@ def list_groups( self, directory: Optional[str] = None, user: Optional[str] = None, - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", @@ -246,7 +247,7 @@ def list_directories( self, domain: Optional[str] = None, search: Optional[str] = None, - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, organization: Optional[str] = None, @@ -318,7 +319,7 @@ async def list_users( self, directory: Optional[str] = None, group: Optional[str] = None, - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", @@ -368,7 +369,7 @@ async def list_groups( self, directory: Optional[str] = None, user: Optional[str] = None, - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", @@ -468,7 +469,7 @@ async def list_directories( self, domain: Optional[str] = None, search: Optional[str] = None, - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, organization: Optional[str] = None, diff --git a/workos/events.py b/workos/events.py index 809c7315..80e838f5 100644 --- a/workos/events.py +++ b/workos/events.py @@ -2,7 +2,7 @@ import workos from workos.typing.sync_or_async import SyncOrAsync -from workos.utils.request import REQUEST_METHOD_GET +from workos.utils.request import DEFAULT_LIST_RESPONSE_LIMIT, REQUEST_METHOD_GET from workos.resources.events import Event, EventType from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient from workos.utils.validation import EVENTS_MODULE, validate_settings @@ -12,8 +12,6 @@ WorkOsListResource, ) -RESPONSE_LIMIT = 10 - class EventsListFilters(ListArgs, total=False): events: List[EventType] @@ -29,7 +27,7 @@ class EventsModule(Protocol): def list_events( self, events: List[EventType], - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, organization_id: Optional[str] = None, after: Optional[str] = None, range_start: Optional[str] = None, @@ -49,7 +47,7 @@ def __init__(self, http_client: SyncHTTPClient): def list_events( self, events: List[EventType], - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, organization: Optional[str] = None, after: Optional[str] = None, range_start: Optional[str] = None, @@ -103,7 +101,7 @@ def __init__(self, http_client: AsyncHTTPClient): async def list_events( self, events: List[EventType], - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, organization_id: Optional[str] = None, after: Optional[str] = None, range_start: Optional[str] = None, diff --git a/workos/organizations.py b/workos/organizations.py index 34d4d0dc..a43baf30 100644 --- a/workos/organizations.py +++ b/workos/organizations.py @@ -2,6 +2,7 @@ import workos from workos.utils.pagination_order import PaginationOrder from workos.utils.request import ( + DEFAULT_LIST_RESPONSE_LIMIT, RequestHelper, REQUEST_METHOD_DELETE, REQUEST_METHOD_GET, @@ -16,7 +17,6 @@ from workos.resources.list import ListPage, WorkOsListResource, ListArgs ORGANIZATIONS_PATH = "organizations" -RESPONSE_LIMIT = 10 class OrganizationListFilters(ListArgs, total=False): @@ -27,7 +27,7 @@ class OrganizationsModule(Protocol): def list_organizations( self, domains: Optional[List[str]] = None, - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", @@ -68,7 +68,7 @@ def request_helper(self): def list_organizations( self, domains: Optional[List[str]] = None, - limit: int = RESPONSE_LIMIT, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, before: Optional[str] = None, after: Optional[str] = None, order: PaginationOrder = "desc", diff --git a/workos/resources/list.py b/workos/resources/list.py index e714f3f2..560eb6d5 100644 --- a/workos/resources/list.py +++ b/workos/resources/list.py @@ -19,6 +19,7 @@ from workos.resources.events import Event from workos.resources.organizations import Organization from pydantic import BaseModel, Field +from workos.resources.sso import Connection from workos.resources.workos_model import WorkOSModel @@ -116,11 +117,12 @@ def auto_paging_iter(self): ListableResource = TypeVar( # add all possible generics of List Resource "ListableResource", - Organization, + Connection, Directory, DirectoryGroup, DirectoryUser, Event, + Organization, ) diff --git a/workos/resources/sso.py b/workos/resources/sso.py index c9ec223e..1d9216d3 100644 --- a/workos/resources/sso.py +++ b/workos/resources/sso.py @@ -1,87 +1,61 @@ -from workos.resources.base import WorkOSBaseResource +from typing import List, Literal, Union +from workos.resources.workos_model import WorkOSModel +from workos.typing.literals import LiteralOrUntyped +from workos.utils.connection_types import ConnectionType -class WorkOSProfile(WorkOSBaseResource): - """Representation of a User Profile as returned by WorkOS through the SSO feature. - Attributes: - OBJECT_FIELDS (list): List of fields a WorkOSProfile is comprised of. - """ +class Profile(WorkOSModel): + """Representation of a User Profile as returned by WorkOS through the SSO feature.""" - OBJECT_FIELDS = [ - "id", - "email", - "first_name", - "last_name", - "groups", - "organization_id", - "connection_id", - "connection_type", - "idp_id", - "raw_attributes", - ] + object: Literal["profile"] + id: str + connection_id: str + connection_type: LiteralOrUntyped[ConnectionType] + organization_id: Union[str, None] + email: str + first_name: Union[str, None] + last_name: Union[str, None] + idp_id: str + groups: Union[List[str], None] + raw_attributes: dict -class WorkOSProfileAndToken(WorkOSBaseResource): - """Representation of a User Profile and Access Token as returned by WorkOS through the SSO feature. +class ProfileAndToken(WorkOSModel): + """Representation of a User Profile and Access Token as returned by WorkOS through the SSO feature.""" - Attributes: - OBJECT_FIELDS (list): List of fields a WorkOSProfileAndToken is comprised of. - """ + access_token: str + profile: Profile - OBJECT_FIELDS = [ - "access_token", - ] - @classmethod - def construct_from_response(cls, response): - profile_and_token = super(WorkOSProfileAndToken, cls).construct_from_response( - response - ) +ConnectionState = Literal[ + "active", "deleting", "inactive", "requires_type", "validating" +] - profile_and_token.profile = WorkOSProfile.construct_from_response( - response["profile"] - ) - return profile_and_token +class ConnectionDomain(WorkOSModel): + object: Literal["connection_domain"] + id: str + domain: str - def to_dict(self): - profile_and_token_dict = super(WorkOSProfileAndToken, self).to_dict() - profile_dict = self.profile.to_dict() - profile_and_token_dict["profile"] = profile_dict +class Connection(WorkOSModel): + """Representation of a Connection Response as returned by WorkOS through the SSO feature.""" - return profile_and_token_dict + object: Literal["connection"] + id: str + organization_id: str + connection_type: LiteralOrUntyped[ConnectionType] + name: str + state: LiteralOrUntyped[ConnectionState] + created_at: str + updated_at: str + domains: List[ConnectionDomain] -class WorkOSConnection(WorkOSBaseResource): - """Representation of a Connection Response as returned by WorkOS through the SSO feature. - Attributes: - OBJECT_FIELDS (list): List of fields a WorkOSConnection is comprised of. - """ - - OBJECT_FIELDS = [ - "object", - "id", - "organization_id", - "connection_type", - "name", - "state", - "created_at", - "updated_at", - "status", - "domains", - ] - - @classmethod - def construct_from_response(cls, response): - connection_response = super(WorkOSConnection, cls).construct_from_response( - response - ) - - return connection_response - - def to_dict(self): - connection_response_dict = super(WorkOSConnection, self).to_dict() - - return connection_response_dict +SsoProviderType = Literal[ + "AppleOAuth", + "GitHubOAuth", + "GoogleOAuth", + "MicrosoftOAuth", +] diff --git a/workos/sso.py b/workos/sso.py index 08b2d9f4..293b69a5 100644 --- a/workos/sso.py +++ b/workos/sso.py @@ -1,25 +1,32 @@ -from typing import Protocol -from warnings import warn +from typing import Optional, Protocol, Union -from requests import Request import workos -from workos.utils.pagination_order import Order +from workos.typing.sync_or_async import SyncOrAsync +from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient +from workos.utils.pagination_order import PaginationOrder from workos.resources.sso import ( - WorkOSProfile, - WorkOSProfileAndToken, - WorkOSConnection, + Connection, + Profile, + ProfileAndToken, + SsoProviderType, ) from workos.utils.connection_types import ConnectionType -from workos.utils.sso_provider_types import SsoProviderType from workos.utils.request import ( - RequestHelper, + DEFAULT_LIST_RESPONSE_LIMIT, RESPONSE_TYPE_CODE, REQUEST_METHOD_DELETE, REQUEST_METHOD_GET, REQUEST_METHOD_POST, + RequestHelper, ) from workos.utils.validation import SSO_MODULE, validate_settings -from workos.resources.list import WorkOSListResource +from workos.resources.list import ( + AsyncWorkOsListResource, + ListArgs, + ListPage, + SyncOrAsyncListResource, + WorkOsListResource, +) AUTHORIZATION_PATH = "sso/authorize" TOKEN_PATH = "sso/token" @@ -27,90 +34,38 @@ OAUTH_GRANT_TYPE = "authorization_code" -RESPONSE_LIMIT = 10 - - -class SSOModule(Protocol): - def get_authorization_url( - self, - domain=None, - domain_hint=None, - login_hint=None, - redirect_uri=None, - state=None, - provider=None, - connection=None, - organization=None, - ) -> str: ... - - def get_profile(self, accessToken: str) -> WorkOSProfile: ... - - def get_profile_and_token(self, code: str) -> WorkOSProfileAndToken: ... - - def get_connection(self, connection: str) -> dict: ... - def list_connections( - self, - connection_type=None, - domain=None, - organization_id=None, - limit=None, - before=None, - after=None, - order=None, - ) -> dict: ... - - def list_connections_v2( - self, - connection_type=None, - domain=None, - organization_id=None, - limit=None, - before=None, - after=None, - order=None, - ) -> dict: ... +class ConnectionsListFilters(ListArgs, total=False): + connection_type: Optional[ConnectionType] + domain: Optional[str] + organization_id: Optional[str] - def delete_connection(self, connection: str) -> None: ... - -class SSO(SSOModule, WorkOSListResource): - """Offers methods to assist in authenticating through the WorkOS SSO service.""" - - @validate_settings(SSO_MODULE) - def __init__(self): - pass - - @property - def request_helper(self): - if not getattr(self, "_request_helper", None): - self._request_helper = RequestHelper() - return self._request_helper +class SSOModule(Protocol): + _http_client: Union[SyncHTTPClient, AsyncHTTPClient] def get_authorization_url( self, - domain=None, - domain_hint=None, - login_hint=None, - redirect_uri=None, - state=None, - provider=None, - connection=None, - organization=None, - ): + redirect_uri: str, + domain_hint: Optional[str] = None, + login_hint: Optional[str] = None, + state: Optional[str] = None, + provider: Optional[SsoProviderType] = None, + connection_id: Optional[str] = None, + organization_id: Optional[str] = None, + ) -> str: """Generate an OAuth 2.0 authorization URL. The URL generated will redirect a User to the Identity Provider configured through WorkOS. Kwargs: - domain (str) - The domain a user is associated with, as configured on WorkOS redirect_uri (str) - A valid redirect URI, as specified on WorkOS state (str) - An encoded string passed to WorkOS that'd be preserved through the authentication workflow, passed back as a query parameter - provider (SsoProviderType) - Authentication service provider descriptor - connection (string) - Unique identifier for a WorkOS Connection - organization (string) - Unique identifier for a WorkOS Organization + provider (SSOProviderType) - Authentication service provider descriptor + connection_id (string) - Unique identifier for a WorkOS Connection + organization_id (string) - Unique identifier for a WorkOS Organization Returns: str: URL to redirect a User to to begin the OAuth workflow with WorkOS @@ -121,50 +76,58 @@ def get_authorization_url( "response_type": RESPONSE_TYPE_CODE, } - if ( - domain is None - and provider is None - and connection is None - and organization is None - ): + if connection_id is None and organization_id is None and provider is None: raise ValueError( - "Incomplete arguments. Need to specify either a 'connection', 'organization', 'domain', or 'provider'" + "Incomplete arguments. Need to specify either a 'connection', 'organization', or 'provider'" ) if provider is not None: - if not isinstance(provider, SsoProviderType): - raise ValueError("'provider' must be of type SsoProviderType") - - params["provider"] = provider.value - if domain is not None: - warn( - "The 'domain' parameter for 'get_authorization_url' is deprecated. Please use 'organization' instead.", - DeprecationWarning, - ) - params["domain"] = domain + params["provider"] = provider if domain_hint is not None: params["domain_hint"] = domain_hint if login_hint is not None: params["login_hint"] = login_hint - if connection is not None: - params["connection"] = connection - if organization is not None: - params["organization"] = organization + if connection_id is not None: + params["connection"] = connection_id + if organization_id is not None: + params["organization"] = organization_id if state is not None: params["state"] = state - if redirect_uri is None: - raise ValueError("Incomplete arguments. Need to specify a 'redirect_uri'.") + return RequestHelper.build_url_with_query_params( + self._http_client.base_url, **params + ) - prepared_request = Request( - "GET", - self.request_helper.generate_api_url(AUTHORIZATION_PATH), - params=params, - ).prepare() + def get_profile(self, accessToken: str) -> SyncOrAsync[Profile]: ... + + def get_profile_and_token(self, code: str) -> SyncOrAsync[ProfileAndToken]: ... - return prepared_request.url + def get_connection(self, connection: str) -> SyncOrAsync[Connection]: ... + + def list_connections( + self, + connection_type: Optional[ConnectionType] = None, + domain: Optional[str] = None, + organization_id: Optional[str] = None, + limit: int = DEFAULT_LIST_RESPONSE_LIMIT, + before: Optional[str] = None, + after: Optional[str] = None, + order: PaginationOrder = "desc", + ) -> SyncOrAsyncListResource: ... - def get_profile(self, accessToken): + def delete_connection(self, connection: str) -> SyncOrAsync[None]: ... + + +class SSO(SSOModule): + """Offers methods to assist in authenticating through the WorkOS SSO service.""" + + _http_client: SyncHTTPClient + + @validate_settings(SSO_MODULE) + def __init__(self, http_client: SyncHTTPClient): + self._http_client = http_client + + def get_profile(self, access_token: str) -> Profile: """ Verify that SSO has been completed successfully and retrieve the identity of the user. @@ -172,18 +135,15 @@ def get_profile(self, accessToken): accessToken (str): the token used to authenticate the API call Returns: - WorkOSProfile + Profile """ - - token = accessToken - - response = self.request_helper.request( - PROFILE_PATH, method=REQUEST_METHOD_GET, token=token + response = self._http_client.request( + PROFILE_PATH, method=REQUEST_METHOD_GET, token=access_token ) - return WorkOSProfile.construct_from_response(response) + return Profile.model_validate(response) - def get_profile_and_token(self, code): + def get_profile_and_token(self, code: str) -> ProfileAndToken: """Get the profile of an authenticated User Once authenticated, using the code returned having followed the authorization URL, @@ -193,7 +153,7 @@ def get_profile_and_token(self, code): code (str): Code returned by WorkOS on completion of OAuth 2.0 workflow Returns: - WorkOSProfileAndToken: WorkOSProfileAndToken object representing the User + ProfileAndToken: WorkOSProfileAndToken object representing the User """ params = { "client_id": workos.client_id, @@ -202,13 +162,13 @@ def get_profile_and_token(self, code): "grant_type": OAUTH_GRANT_TYPE, } - response = self.request_helper.request( + response = self._http_client.request( TOKEN_PATH, method=REQUEST_METHOD_POST, params=params ) - return WorkOSProfileAndToken.construct_from_response(response) + return ProfileAndToken.model_validate(response) - def get_connection(self, connection): + def get_connection(self, connection_id: str) -> Connection: """Gets details for a single Connection Args: @@ -217,24 +177,24 @@ def get_connection(self, connection): Returns: dict: Connection response from WorkOS. """ - response = self.request_helper.request( - "connections/{connection}".format(connection=connection), + response = self._http_client.request( + f"connections/{connection_id}", method=REQUEST_METHOD_GET, token=workos.api_key, ) - return WorkOSConnection.construct_from_response(response).to_dict() + return Connection.model_validate(response) def list_connections( self, - connection_type=None, - domain=None, - organization_id=None, - limit=None, - before=None, - after=None, - order=None, - ): + connection_type: Optional[ConnectionType] = None, + domain: Optional[str] = None, + organization_id: Optional[str] = None, + limit: Optional[int] = DEFAULT_LIST_RESPONSE_LIMIT, + before: Optional[str] = None, + after: Optional[str] = None, + order: PaginationOrder = "desc", + ) -> WorkOsListResource[Connection, ConnectionsListFilters]: """Gets details for existing Connections. Args: @@ -248,33 +208,9 @@ def list_connections( Returns: dict: Connections response from WorkOS. """ - warn( - "The 'list_connections' method is deprecated. Please use 'list_connections_v2' instead.", - DeprecationWarning, - ) - - # This method used to accept `connection_type` as a string, so we try - # to convert strings to a `ConnectionType` to support existing callers. - # - # TODO: Remove support for string values of `ConnectionType` in the next - # major version. - if connection_type is not None and isinstance(connection_type, str): - try: - connection_type = ConnectionType[connection_type] - - warn( - "Passing a string value as the 'connection_type' parameter for 'list_connections' is deprecated and will be removed in the next major version. Please pass a 'ConnectionType' instead.", - DeprecationWarning, - ) - except KeyError: - raise ValueError("'connection_type' must be a member of ConnectionType") - - if limit is None: - limit = RESPONSE_LIMIT - default_limit = True params = { - "connection_type": connection_type.value if connection_type else None, + "connection_type": connection_type, "domain": domain, "organization_id": organization_id, "limit": limit, @@ -283,45 +219,109 @@ def list_connections( "order": order or "desc", } - if order is not None: - if isinstance(order, Order): - params["order"] = str(order.value) - - elif order == "asc" or order == "desc": - params["order"] = order - else: - raise ValueError("Parameter order must be of enum type Order") - - response = self.request_helper.request( + response = self._http_client.request( "connections", method=REQUEST_METHOD_GET, params=params, token=workos.api_key, ) - response["metadata"] = { - "params": params, - "method": SSO.list_connections, + return WorkOsListResource( + list_method=self.list_connections, + list_args=params, + **ListPage[Connection](**response).model_dump(), + ) + + def delete_connection(self, connection_id: str) -> None: + """Deletes a single Connection + + Args: + connection (str): Connection unique identifier + """ + self._http_client.request( + f"connections/{connection_id}", + method=REQUEST_METHOD_DELETE, + token=workos.api_key, + ) + + +class AsyncSSO(SSOModule): + """Offers methods to assist in authenticating through the WorkOS SSO service.""" + + _http_client: AsyncHTTPClient + + @validate_settings(SSO_MODULE) + def __init__(self, http_client: AsyncHTTPClient): + self._http_client = http_client + + async def get_profile(self, access_token: str) -> Profile: + """ + Verify that SSO has been completed successfully and retrieve the identity of the user. + + Args: + accessToken (str): the token used to authenticate the API call + + Returns: + Profile + """ + response = await self._http_client.request( + PROFILE_PATH, method=REQUEST_METHOD_GET, token=access_token + ) + + return Profile.model_validate(response) + + async def get_profile_and_token(self, code: str) -> ProfileAndToken: + """Get the profile of an authenticated User + + Once authenticated, using the code returned having followed the authorization URL, + get the WorkOS profile of the User. + + Args: + code (str): Code returned by WorkOS on completion of OAuth 2.0 workflow + + Returns: + ProfileAndToken: WorkOSProfileAndToken object representing the User + """ + params = { + "client_id": workos.client_id, + "client_secret": workos.api_key, + "code": code, + "grant_type": OAUTH_GRANT_TYPE, } - if "default_limit" in locals(): - if "metadata" in response and "params" in response["metadata"]: - response["metadata"]["params"]["default_limit"] = default_limit - else: - response["metadata"] = {"params": {"default_limit": default_limit}} + response = await self._http_client.request( + TOKEN_PATH, method=REQUEST_METHOD_POST, params=params + ) + + return ProfileAndToken.model_validate(response) + + async def get_connection(self, connection_id: str) -> Connection: + """Gets details for a single Connection + + Args: + connection (str): Connection unique identifier + + Returns: + dict: Connection response from WorkOS. + """ + response = await self._http_client.request( + f"connections/{connection_id}", + method=REQUEST_METHOD_GET, + token=workos.api_key, + ) - return response + return Connection.model_validate(response) - def list_connections_v2( + async def list_connections( self, - connection_type=None, - domain=None, - organization_id=None, - limit=None, - before=None, - after=None, - order=None, - ): + connection_type: Optional[ConnectionType] = None, + domain: Optional[str] = None, + organization_id: Optional[str] = None, + limit: Optional[int] = DEFAULT_LIST_RESPONSE_LIMIT, + before: Optional[str] = None, + after: Optional[str] = None, + order: PaginationOrder = "desc", + ) -> AsyncWorkOsListResource[Connection, ConnectionsListFilters]: """Gets details for existing Connections. Args: @@ -336,28 +336,8 @@ def list_connections_v2( dict: Connections response from WorkOS. """ - # This method used to accept `connection_type` as a string, so we try - # to convert strings to a `ConnectionType` to support existing callers. - # - # TODO: Remove support for string values of `ConnectionType` in the next - # major version. - if connection_type is not None and isinstance(connection_type, str): - try: - connection_type = ConnectionType[connection_type] - - warn( - "Passing a string value as the 'connection_type' parameter for 'list_connections' is deprecated and will be removed in the next major version. Please pass a 'ConnectionType' instead.", - DeprecationWarning, - ) - except KeyError: - raise ValueError("'connection_type' must be a member of ConnectionType") - - if limit is None: - limit = RESPONSE_LIMIT - default_limit = True - params = { - "connection_type": connection_type.value if connection_type else None, + "connection_type": connection_type, "domain": domain, "organization_id": organization_id, "limit": limit, @@ -366,43 +346,27 @@ def list_connections_v2( "order": order or "desc", } - if order is not None: - if isinstance(order, Order): - params["order"] = str(order.value) - - elif order == "asc" or order == "desc": - params["order"] = order - else: - raise ValueError("Parameter order must be of enum type Order") - - response = self.request_helper.request( + response = await self._http_client.request( "connections", method=REQUEST_METHOD_GET, params=params, token=workos.api_key, ) - response["metadata"] = { - "params": params, - "method": SSO.list_connections_v2, - } - - if "default_limit" in locals(): - if "metadata" in response and "params" in response["metadata"]: - response["metadata"]["params"]["default_limit"] = default_limit - else: - response["metadata"] = {"params": {"default_limit": default_limit}} - - return self.construct_from_response(response) + return AsyncWorkOsListResource( + list_method=self.list_connections, + list_args=params, + **ListPage[Connection](**response).model_dump(), + ) - def delete_connection(self, connection): + async def delete_connection(self, connection_id: str) -> None: """Deletes a single Connection Args: connection (str): Connection unique identifier """ - return self.request_helper.request( - "connections/{connection}".format(connection=connection), + await self._http_client.request( + f"connections/{connection_id}", method=REQUEST_METHOD_DELETE, token=workos.api_key, ) diff --git a/workos/user_management.py b/workos/user_management.py index 6fef217d..9f9931fb 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -19,6 +19,7 @@ from workos.utils.pagination_order import Order from workos.utils.um_provider_types import UserManagementProviderType from workos.utils.request import ( + DEFAULT_LIST_RESPONSE_LIMIT, RequestHelper, RESPONSE_TYPE_CODE, REQUEST_METHOD_POST, @@ -56,8 +57,6 @@ PASSWORD_RESET_PATH = "user_management/password_reset" PASSWORD_RESET_DETAIL_PATH = "user_management/password_reset/{0}" -RESPONSE_LIMIT = 10 - class UserManagementModule(Protocol): def get_user(self, user_id: str) -> dict: ... @@ -293,7 +292,7 @@ def list_users( default_limit = None if limit is None: - limit = RESPONSE_LIMIT + limit = DEFAULT_LIST_RESPONSE_LIMIT default_limit = True params = { @@ -504,7 +503,7 @@ def list_organization_memberships( default_limit = None if limit is None: - limit = RESPONSE_LIMIT + limit = DEFAULT_LIST_RESPONSE_LIMIT default_limit = True if statuses is not None: @@ -1449,7 +1448,7 @@ def list_invitations( default_limit = None if limit is None: - limit = RESPONSE_LIMIT + limit = DEFAULT_LIST_RESPONSE_LIMIT default_limit = True params = { diff --git a/workos/utils/_base_http_client.py b/workos/utils/_base_http_client.py index 91d4dccb..13b5637b 100644 --- a/workos/utils/_base_http_client.py +++ b/workos/utils/_base_http_client.py @@ -176,6 +176,16 @@ def _handle_response(self, response: httpx.Response) -> dict: return cast(Dict, response_json) + def build_request_url( + self, + url: str, + method: Optional[str] = REQUEST_METHOD_GET, + params: Optional[Mapping] = None, + ) -> str: + return self._client.build_request( + method=method or REQUEST_METHOD_GET, url=url, params=params + ).url.__str__() + @property def base_url(self) -> str: return self._base_url @@ -207,3 +217,7 @@ def user_agent(self) -> str: @property def timeout(self) -> int: return self._timeout + + @property + def version(self) -> str: + return self._version diff --git a/workos/utils/connection_types.py b/workos/utils/connection_types.py index 32e1ef20..739026c4 100644 --- a/workos/utils/connection_types.py +++ b/workos/utils/connection_types.py @@ -1,48 +1,39 @@ -from enum import Enum +from typing import Literal -class ConnectionType(Enum): - ADFSSAML = "ADFSSAML" - AdpOidc = "AdpOidc" - AppleOAuth = "AppleOAuth" - Auth0SAML = "Auth0SAML" - AzureSAML = "AzureSAML" - CasSAML = "CasSAML" - CloudflareSAML = "CloudflareSAML" - ClassLinkSAML = "ClassLinkSAML" - CyberArkSAML = "CyberArkSAML" - DuoSAML = "DuoSAML" - GenericOIDC = "GenericOIDC" - GenericSAML = "GenericSAML" - GitHubOAuth = "GitHubOAuth" - GoogleOAuth = "GoogleOAuth" - GoogleSAML = "GoogleSAML" - JumpCloudSAML = "JumpCloudSAML" - KeycloakSAML = "KeycloakSAML" - LastPassSAML = "LastPassSAML" - MagicLink = "MagicLink" - MicrosoftOAuth = "MicrosoftOAuth" - MiniOrangeSAML = "MiniOrangeSAML" - NetIqSAML = "NetIqSAML" - OktaSAML = "OktaSAML" - OneLoginSAML = "OneLoginSAML" - OracleSAML = "OracleSAML" - PingFederateSAML = "PingFederateSAML" - PingOneSAML = "PingOneSAML" - RipplingSAML = "RipplingSAML" - SalesforceSAML = "SalesforceSAML" - ShibbolethGenericSAML = "ShibbolethGenericSAML" - ShibbolethSAML = "ShibbolethSAML" - SimpleSamlPhpSAML = "SimpleSamlPhpSAML" - VMwareSAML = "VMwareSAML" - - @classmethod - def providers(cls): - """Returns a generator of all connection types/providers. - This is only needed as a workaround for providers passed - as a string connection type. - - Returns: - generator(list): A lazy list of all connection types - """ - return (connection_type.value for connection_type in ConnectionType) +ConnectionType = Literal[ + "ADFSSAML", + "AdpOidc", + "AppleOAuth", + "Auth0SAML", + "AzureSAML", + "CasSAML", + "CloudflareSAML", + "ClassLinkSAML", + "CyberArkSAML", + "DuoSAML", + "GenericOIDC", + "GenericSAML", + "GitHubOAuth", + "GoogleOAuth", + "GoogleSAML", + "JumpCloudSAML", + "KeycloakSAML", + "LastPassSAML", + "LoginGovOidc", + "MagicLink", + "MicrosoftOAuth", + "MiniOrangeSAML", + "NetIqSAML", + "OktaSAML", + "OneLoginSAML", + "OracleSAML", + "PingFederateSAML", + "PingOneSAML", + "RipplingSAML", + "SalesforceSAML", + "ShibbolethGenericSAML", + "ShibbolethSAML", + "SimpleSamlPhpSAML", + "VMwareSAML", +] diff --git a/workos/utils/request.py b/workos/utils/request.py index e54aeafd..09173f18 100644 --- a/workos/utils/request.py +++ b/workos/utils/request.py @@ -21,8 +21,8 @@ ), } +DEFAULT_LIST_RESPONSE_LIMIT = 10 RESPONSE_TYPE_CODE = "code" - REQUEST_METHOD_DELETE = "delete" REQUEST_METHOD_GET = "get" REQUEST_METHOD_POST = "post" @@ -52,6 +52,10 @@ def build_parameterized_url(self, url, **params): escaped_params = {k: urllib.parse.quote(str(v)) for k, v in params.items()} return url.format(**escaped_params) + @classmethod + def build_url_with_query_params(cls, url, **params): + return url.format("?" + urllib.parse.urlencode(params)) + def request( self, path, diff --git a/workos/utils/sso_provider_types.py b/workos/utils/sso_provider_types.py deleted file mode 100644 index 1be0120d..00000000 --- a/workos/utils/sso_provider_types.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import Enum - - -class SsoProviderType(Enum): - AppleOAuth = "AppleOAuth" - GitHubOAuth = "GitHubOAuth" - GoogleOAuth = "GoogleOAuth" - MicrosoftOAuth = "MicrosoftOAuth"