From 37f36c6a527f8f9d743d97dd8441599a94ff632b Mon Sep 17 00:00:00 2001 From: pantera Date: Wed, 31 Jul 2024 11:26:14 -0700 Subject: [PATCH] Type webhook methods (#307) * Type inputs to webhook verification methods * Return well defined types from webhook verification * mypy fixes * Remove some unneeded tests * Update test fixtures with new values * Add explicit None --- tests/test_webhooks.py | 28 +--- workos/resources/webhooks.py | 283 +++++++++++++++++++++++++++++++++++ workos/webhooks.py | 56 ++++--- 3 files changed, 325 insertions(+), 42 deletions(-) create mode 100644 workos/resources/webhooks.py diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py index 1d91243b..f14f08d0 100644 --- a/tests/test_webhooks.py +++ b/tests/test_webhooks.py @@ -1,3 +1,4 @@ +import datetime import json from os import error from workos.webhooks import Webhooks @@ -16,15 +17,15 @@ def setup(self, set_api_key): @pytest.fixture def mock_event_body(self): - return '{"id":"wh_01FG9JXJ9C9S052FX59JVG4EG1","data":{"id":"conn_01EHWNC0FCBHZ3BJ7EGKYXK0E6","name":"Foo Corp\'s Connection","state":"active","object":"connection","domains":[{"id":"conn_domain_01EHWNFTAFCF3CQAE5A9Q0P1YB","domain":"foo-corp.com","object":"connection_domain"}],"connection_type":"OktaSAML","organization_id":"org_01EHWNCE74X7JSDV0X3SZ3KJNY"},"event":"connection.activated"}' + return '{"id":"event_01J44T8116Q5M0RYCFA6KWNXN9","data":{"id":"conn_01EHWNC0FCBHZ3BJ7EGKYXK0E6","name":"Foo Corp\'s Connection","state":"active","object":"connection","status":"linked","domains":[{"id":"conn_domain_01EHWNFTAFCF3CQAE5A9Q0P1YB","domain":"foo-corp.com","object":"connection_domain"}],"created_at":"2021-06-25T19:07:33.155Z","updated_at":"2021-06-25T19:07:33.155Z","external_key":"3QMR4u0Tok6SgwY2AWG6u6mkQ","connection_type":"OktaSAML","organization_id":"org_01EHWNCE74X7JSDV0X3SZ3KJNY"},"event":"connection.activated","created_at":"2021-06-25T19:07:33.155Z"}' @pytest.fixture def mock_header(self): - return "t=1632409405772, v1=67612f0e74f008b436a13b00266f90ef5c13f9cbcf6262206f5f4a539ff61702" + return "t=1722443701539, v1=bd54a3768f461461c8439c2f97ab0d646ef3976f84d5d5b132d18f2fa89cdad5" @pytest.fixture def mock_secret(self): - return "1lyKDzhJjuCkIscIWqkSe4YsQ" + return "2sAZJlbjP8Ce3rwkKEv2GfKef" @pytest.fixture def mock_bad_secret(self): @@ -32,31 +33,12 @@ def mock_bad_secret(self): @pytest.fixture def mock_header_no_timestamp(self): - return "v1=67612f0e74f008b436a13b00266f90ef5c13f9cbcf6262206f5f4a539ff61702" + return "v1=bd54a3768f461461c8439c2f97ab0d646ef3976f84d5d5b132d18f2fa89cdad5" @pytest.fixture def mock_sig_hash(self): return "df25b6efdd39d82e7b30e75ea19655b306860ad5cde3eeaeb6f1dfea029ea259" - def test_missing_body(self, mock_header, mock_secret): - with pytest.raises(ValueError) as err: - self.webhooks.verify_event(None, mock_header, mock_secret) - assert "Payload body is missing and is a required parameter" in str(err.value) - - def test_missing_header(self, mock_event_body, mock_secret): - with pytest.raises(ValueError) as err: - self.webhooks.verify_event( - mock_event_body.encode("utf-8"), None, mock_secret - ) - assert "Payload signature missing and is a required parameter" in str(err.value) - - def test_missing_secret(self, mock_event_body, mock_header): - with pytest.raises(ValueError) as err: - self.webhooks.verify_event( - mock_event_body.encode("utf-8"), mock_header, None - ) - assert "Secret is missing and is a required parameter" in str(err.value) - def test_unable_to_extract_timestamp( self, mock_event_body, mock_header_no_timestamp, mock_secret ): diff --git a/workos/resources/webhooks.py b/workos/resources/webhooks.py new file mode 100644 index 00000000..6579fa68 --- /dev/null +++ b/workos/resources/webhooks.py @@ -0,0 +1,283 @@ +from typing import Generic, Literal, Union +from pydantic import Field +from typing_extensions import Annotated +from workos.resources.directory_sync import DirectoryGroup +from workos.resources.events import EventPayload +from workos.resources.user_management import OrganizationMembership, User +from workos.resources.workos_model import WorkOSModel +from workos.types.directory_sync.directory_user import DirectoryUser +from workos.types.events.authentication_payload import ( + AuthenticationEmailVerificationSucceededPayload, + AuthenticationMagicAuthFailedPayload, + AuthenticationMagicAuthSucceededPayload, + AuthenticationMfaSucceededPayload, + AuthenticationOauthSucceededPayload, + AuthenticationPasswordFailedPayload, + AuthenticationPasswordSucceededPayload, + AuthenticationSsoSucceededPayload, +) +from workos.types.events.connection_payload_with_legacy_fields import ( + ConnectionPayloadWithLegacyFields, +) +from workos.types.events.directory_group_membership_payload import ( + DirectoryGroupMembershipPayload, +) +from workos.types.events.directory_group_with_previous_attributes import ( + DirectoryGroupWithPreviousAttributes, +) +from workos.types.events.directory_payload import DirectoryPayload +from workos.types.events.directory_payload_with_legacy_fields import ( + DirectoryPayloadWithLegacyFields, +) +from workos.types.events.directory_user_with_previous_attributes import ( + DirectoryUserWithPreviousAttributes, +) +from workos.types.events.organization_domain_verification_failed_payload import ( + OrganizationDomainVerificationFailedPayload, +) +from workos.types.events.session_created_payload import SessionCreatedPayload +from workos.types.organizations.organization_common import OrganizationCommon +from workos.types.organizations.organization_domain import OrganizationDomain +from workos.types.roles.role import Role +from workos.types.sso.connection import Connection +from workos.types.user_management.email_verification_common import ( + EmailVerificationCommon, +) +from workos.types.user_management.invitation_common import InvitationCommon +from workos.types.user_management.magic_auth_common import MagicAuthCommon +from workos.types.user_management.password_reset_common import PasswordResetCommon + + +class WebhookModel(WorkOSModel, Generic[EventPayload]): + """Representation of an Webhook delivered via Webhook. + Attributes: + OBJECT_FIELDS (list): List of fields an Webhook is comprised of. + """ + + id: str + data: EventPayload + created_at: str + + +class AuthenticationEmailVerificationSucceededWebhook( + WebhookModel[AuthenticationEmailVerificationSucceededPayload,] +): + event: Literal["authentication.email_verification_succeeded"] + + +class AuthenticationMagicAuthFailedWebhook( + WebhookModel[AuthenticationMagicAuthFailedPayload,] +): + event: Literal["authentication.magic_auth_failed"] + + +class AuthenticationMagicAuthSucceededWebhook( + WebhookModel[AuthenticationMagicAuthSucceededPayload,] +): + event: Literal["authentication.magic_auth_succeeded"] + + +class AuthenticationMfaSucceededWebhook( + WebhookModel[AuthenticationMfaSucceededPayload] +): + event: Literal["authentication.mfa_succeeded"] + + +class AuthenticationOauthSucceededWebhook( + WebhookModel[AuthenticationOauthSucceededPayload] +): + event: Literal["authentication.oauth_succeeded"] + + +class AuthenticationPasswordFailedWebhook( + WebhookModel[AuthenticationPasswordFailedPayload] +): + event: Literal["authentication.password_failed"] + + +class AuthenticationPasswordSucceededWebhook( + WebhookModel[AuthenticationPasswordSucceededPayload,] +): + event: Literal["authentication.password_succeeded"] + + +class AuthenticationSsoSucceededWebhook( + WebhookModel[AuthenticationSsoSucceededPayload] +): + event: Literal["authentication.sso_succeeded"] + + +class ConnectionActivatedWebhook(WebhookModel[ConnectionPayloadWithLegacyFields]): + event: Literal["connection.activated"] + + +class ConnectionDeactivatedWebhook(WebhookModel[ConnectionPayloadWithLegacyFields]): + event: Literal["connection.deactivated"] + + +class ConnectionDeletedWebhook(WebhookModel[Connection]): + event: Literal["connection.deleted"] + + +class DirectoryActivatedWebhook(WebhookModel[DirectoryPayloadWithLegacyFields]): + event: Literal["dsync.activated"] + + +class DirectoryDeletedWebhook(WebhookModel[DirectoryPayload]): + event: Literal["dsync.deleted"] + + +class DirectoryGroupCreatedWebhook(WebhookModel[DirectoryGroup]): + event: Literal["dsync.group.created"] + + +class DirectoryGroupDeletedWebhook(WebhookModel[DirectoryGroup]): + event: Literal["dsync.group.deleted"] + + +class DirectoryGroupUpdatedWebhook(WebhookModel[DirectoryGroupWithPreviousAttributes]): + event: Literal["dsync.group.updated"] + + +class DirectoryUserCreatedWebhook(WebhookModel[DirectoryUser]): + event: Literal["dsync.user.created"] + + +class DirectoryUserDeletedWebhook(WebhookModel[DirectoryUser]): + event: Literal["dsync.user.deleted"] + + +class DirectoryUserUpdatedWebhook(WebhookModel[DirectoryUserWithPreviousAttributes]): + event: Literal["dsync.user.updated"] + + +class DirectoryUserAddedToGroupWebhook(WebhookModel[DirectoryGroupMembershipPayload]): + event: Literal["dsync.group.user_added"] + + +class DirectoryUserRemovedFromGroupWebhook( + WebhookModel[DirectoryGroupMembershipPayload] +): + event: Literal["dsync.group.user_removed"] + + +class EmailVerificationCreatedWebhook(WebhookModel[EmailVerificationCommon]): + event: Literal["email_verification.created"] + + +class InvitationCreatedWebhook(WebhookModel[InvitationCommon]): + event: Literal["invitation.created"] + + +class MagicAuthCreatedWebhook(WebhookModel[MagicAuthCommon]): + event: Literal["magic_auth.created"] + + +class OrganizationCreatedWebhook(WebhookModel[OrganizationCommon]): + event: Literal["organization.created"] + + +class OrganizationDeletedWebhook(WebhookModel[OrganizationCommon]): + event: Literal["organization.deleted"] + + +class OrganizationUpdatedWebhook(WebhookModel[OrganizationCommon]): + event: Literal["organization.updated"] + + +class OrganizationDomainVerificationFailedWebhook( + WebhookModel[OrganizationDomainVerificationFailedPayload,] +): + event: Literal["organization_domain.verification_failed"] + + +class OrganizationDomainVerifiedWebhook(WebhookModel[OrganizationDomain]): + event: Literal["organization_domain.verified"] + + +class OrganizationMembershipCreatedWebhook(WebhookModel[OrganizationMembership]): + event: Literal["organization_membership.created"] + + +class OrganizationMembershipDeletedWebhook(WebhookModel[OrganizationMembership]): + event: Literal["organization_membership.deleted"] + + +class OrganizationMembershipUpdatedWebhook(WebhookModel[OrganizationMembership]): + event: Literal["organization_membership.updated"] + + +class PasswordResetCreatedWebhook(WebhookModel[PasswordResetCommon]): + event: Literal["password_reset.created"] + + +class RoleCreatedWebhook(WebhookModel[Role]): + event: Literal["role.created"] + + +class RoleDeletedWebhook(WebhookModel[Role]): + event: Literal["role.deleted"] + + +class RoleUpdatedWebhook(WebhookModel[Role]): + event: Literal["role.updated"] + + +class SessionCreatedWebhook(WebhookModel[SessionCreatedPayload]): + event: Literal["session.created"] + + +class UserCreatedWebhook(WebhookModel[User]): + event: Literal["user.created"] + + +class UserDeletedWebhook(WebhookModel[User]): + event: Literal["user.deleted"] + + +class UserUpdatedWebhook(WebhookModel[User]): + event: Literal["user.updated"] + + +Webhook = Annotated[ + Union[ + AuthenticationEmailVerificationSucceededWebhook, + AuthenticationMagicAuthFailedWebhook, + AuthenticationMagicAuthSucceededWebhook, + AuthenticationMfaSucceededWebhook, + AuthenticationOauthSucceededWebhook, + AuthenticationPasswordFailedWebhook, + AuthenticationPasswordSucceededWebhook, + AuthenticationSsoSucceededWebhook, + ConnectionActivatedWebhook, + ConnectionDeactivatedWebhook, + ConnectionDeletedWebhook, + DirectoryActivatedWebhook, + DirectoryDeletedWebhook, + DirectoryGroupCreatedWebhook, + DirectoryGroupDeletedWebhook, + DirectoryGroupUpdatedWebhook, + DirectoryUserCreatedWebhook, + DirectoryUserDeletedWebhook, + DirectoryUserUpdatedWebhook, + DirectoryUserAddedToGroupWebhook, + DirectoryUserRemovedFromGroupWebhook, + EmailVerificationCreatedWebhook, + InvitationCreatedWebhook, + MagicAuthCreatedWebhook, + OrganizationCreatedWebhook, + OrganizationDeletedWebhook, + OrganizationUpdatedWebhook, + OrganizationDomainVerificationFailedWebhook, + OrganizationDomainVerifiedWebhook, + PasswordResetCreatedWebhook, + RoleCreatedWebhook, + RoleDeletedWebhook, + RoleUpdatedWebhook, + SessionCreatedWebhook, + UserCreatedWebhook, + UserDeletedWebhook, + UserUpdatedWebhook, + ], + Field(..., discriminator="event"), +] diff --git a/workos/webhooks.py b/workos/webhooks.py index c97df151..e19cfe5f 100644 --- a/workos/webhooks.py +++ b/workos/webhooks.py @@ -1,18 +1,32 @@ -from typing import Protocol - +from typing import Optional, Protocol, Union +from pydantic import TypeAdapter +from workos.resources.webhooks import Webhook from workos.utils.request_helper import RequestHelper from workos.utils.validation import WEBHOOKS_MODULE, validate_settings import hmac -import json import time -from collections import OrderedDict import hashlib +WebhookPayload = Union[bytes, bytearray] +WebhookTypeAdapter: TypeAdapter[Webhook] = TypeAdapter(Webhook) -class WebhooksModule(Protocol): - def verify_event(self, payload, sig_header, secret, tolerance) -> dict: ... - def verify_header(self, event_body, event_signature, secret, tolerance) -> None: ... +class WebhooksModule(Protocol): + def verify_event( + self, + payload: WebhookPayload, + sig_header: str, + secret: str, + tolerance: Optional[int] = None, + ) -> Webhook: ... + + def verify_header( + self, + event_body: WebhookPayload, + event_signature: str, + secret: str, + tolerance: Optional[int] = None, + ) -> None: ... def constant_time_compare(self, val1, val2) -> bool: ... @@ -34,19 +48,23 @@ def request_helper(self): DEFAULT_TOLERANCE = 180 - def verify_event(self, payload, sig_header, secret, tolerance=DEFAULT_TOLERANCE): - if payload is None: - raise ValueError("Payload body is missing and is a required parameter") - if sig_header is None: - raise ValueError("Payload signature missing and is a required parameter") - if secret is None: - raise ValueError("Secret is missing and is a required parameter") - + def verify_event( + self, + payload: WebhookPayload, + sig_header: str, + secret: str, + tolerance: Optional[int] = DEFAULT_TOLERANCE, + ) -> Webhook: Webhooks.verify_header(self, payload, sig_header, secret, tolerance) - event = json.loads(payload, object_pairs_hook=OrderedDict) - return event - - def verify_header(self, event_body, event_signature, secret, tolerance=None): + return WebhookTypeAdapter.validate_json(payload) + + def verify_header( + self, + event_body: WebhookPayload, + event_signature: str, + secret: str, + tolerance: Optional[int] = None, + ) -> None: try: # Verify and define variables parsed from the event body issued_timestamp, signature_hash = event_signature.split(", ")