diff --git a/requests_oauth2client/__init__.py b/requests_oauth2client/__init__.py index 5af0821..937f5c6 100644 --- a/requests_oauth2client/__init__.py +++ b/requests_oauth2client/__init__.py @@ -21,6 +21,7 @@ AuthorizationRequest, AuthorizationRequestSerializer, AuthorizationResponse, + CodeChallengeMethods, PkceUtils, RequestParameterAuthorizationRequest, RequestUriParameterAuthorizationRequest, @@ -32,6 +33,7 @@ from .client import ( GrantType, OAuth2Client, + TestingOAuth2Client, ) from .client_authentication import ( BaseClientAuthenticationMethod, @@ -128,6 +130,7 @@ "ClientSecretBasic", "ClientSecretJwt", "ClientSecretPost", + "CodeChallengeMethods", "ConsentRequired", "DeviceAuthorizationError", "DeviceAuthorizationPoolingJob", @@ -181,6 +184,7 @@ "SessionSelectionRequired", "SignatureAlgs", "SlowDown", + "TestingOAuth2Client", "TokenEndpointError", "TokenEndpointPoolingJob", "UnauthorizedClient", diff --git a/requests_oauth2client/authorization_request.py b/requests_oauth2client/authorization_request.py index 660b8de..56adf8e 100644 --- a/requests_oauth2client/authorization_request.py +++ b/requests_oauth2client/authorization_request.py @@ -5,6 +5,7 @@ import re import secrets from datetime import datetime +from enum import Enum from typing import Any, Callable, ClassVar, Iterable, Sequence from attrs import Factory, asdict, field, fields, frozen @@ -107,6 +108,13 @@ def validate_code_verifier(cls, verifier: str, challenge: str, method: str = "S2 return cls.code_verifier_re.match(verifier) is not None and cls.derive_challenge(verifier, method) == challenge +class CodeChallengeMethods(str, Enum): + """PKCE Code Challenge Methods.""" + + plain = "plain" + S256 = "S256" + + @frozen(init=False) class AuthorizationResponse: """Represent a successful Authorization Response. diff --git a/requests_oauth2client/client.py b/requests_oauth2client/client.py index ba21ca8..72c33f9 100644 --- a/requests_oauth2client/client.py +++ b/requests_oauth2client/client.py @@ -2,17 +2,20 @@ from __future__ import annotations +import warnings from enum import Enum from typing import Any, Callable, ClassVar, Iterable, TypeVar import requests from attrs import field, frozen from jwskate import Jwk, JwkSet, Jwt, SignatureAlgs +from typing_extensions import override from .auth import BearerAuth from .authorization_request import ( AuthorizationRequest, AuthorizationResponse, + CodeChallengeMethods, RequestUriParameterAuthorizationRequest, ) from .backchannel_authentication import BackChannelAuthenticationResponse @@ -44,7 +47,7 @@ UnsupportedTokenType, ) from .tokens import BearerToken, IdToken, TokenType -from .utils import validate_endpoint_uri +from .utils import validate_endpoint_uri, validate_issuer_uri T = TypeVar("T") @@ -118,7 +121,7 @@ class OAuth2Client: """ auth: requests.auth.AuthBase = field(converter=client_auth_factory) - token_endpoint: str + token_endpoint: str = field() revocation_endpoint: str | None introspection_endpoint: str | None userinfo_endpoint: str | None @@ -224,18 +227,28 @@ def __init__( # noqa: PLR0913 session = requests.Session() self.__attrs_init__( - token_endpoint=token_endpoint, - revocation_endpoint=revocation_endpoint, - introspection_endpoint=introspection_endpoint, - userinfo_endpoint=userinfo_endpoint, - authorization_endpoint=authorization_endpoint, + token_endpoint=self.validate_endpoint_uri(token_endpoint), + revocation_endpoint=self.validate_endpoint_uri(revocation_endpoint) if revocation_endpoint else None, + introspection_endpoint=self.validate_endpoint_uri(introspection_endpoint) + if introspection_endpoint + else None, + userinfo_endpoint=self.validate_endpoint_uri(userinfo_endpoint) if userinfo_endpoint else None, + authorization_endpoint=self.validate_endpoint_uri(authorization_endpoint) + if authorization_endpoint + else None, redirect_uri=redirect_uri, - backchannel_authentication_endpoint=backchannel_authentication_endpoint, - device_authorization_endpoint=device_authorization_endpoint, - pushed_authorization_request_endpoint=pushed_authorization_request_endpoint, - jwks_uri=jwks_uri, + backchannel_authentication_endpoint=self.validate_endpoint_uri(backchannel_authentication_endpoint) + if backchannel_authentication_endpoint + else None, + device_authorization_endpoint=self.validate_endpoint_uri(device_authorization_endpoint) + if device_authorization_endpoint + else None, + pushed_authorization_request_endpoint=self.validate_endpoint_uri(pushed_authorization_request_endpoint) + if pushed_authorization_request_endpoint + else None, + jwks_uri=self.validate_endpoint_uri(jwks_uri) if jwks_uri else None, authorization_server_jwks=authorization_server_jwks, - issuer=issuer, + issuer=self.validate_issuer_uri(issuer) if issuer else None, session=session, auth=auth, id_token_signed_response_alg=id_token_signed_response_alg, @@ -247,6 +260,23 @@ def __init__( # noqa: PLR0913 extra_metadata=extra_metadata, ) + def validate_endpoint_uri(self, uri: str) -> str: + """Validate that an endpoint URI is suitable for use. + + If you need to disable some checks (for AS testing purposes only!), provide a different + method here. + + """ + return validate_endpoint_uri(uri) + + def validate_issuer_uri(self, uri: str) -> str: + """Validate that an Issuer identifier is suitable for use. + + This is the same check as an endpoint URI, but the path may be (and usually is) empty. + + """ + return validate_issuer_uri(uri) + @property def client_id(self) -> str: """Client ID.""" @@ -1473,13 +1503,18 @@ def from_discovery_document( private_key: private key to sign client assertions authorization_server_jwks: the current authorization server JWKS keys session: a requests Session to use to retrieve the document and initialise the client with - https: if True, validates that urls in the discovery document use the https scheme + https: if `True`, validates that urls in the discovery document use the https scheme **kwargs: additional args that will be passed to OAuth2Client Returns: - an OAuth2Client + an `OAuth2Client` """ + if not https: + warnings.warn( + "The https parameter is deprecated. Use the TestingOAuth2Client subclass instead.", stacklevel=1 + ) + cls = TestingOAuth2Client if issuer and discovery.get("issuer") != issuer: msg = "Mismatching issuer value in discovery document: " raise ValueError( @@ -1494,20 +1529,10 @@ def from_discovery_document( if token_endpoint is None: msg = "token_endpoint not found in that discovery document" raise ValueError(msg) - validate_endpoint_uri(token_endpoint, https=https) authorization_endpoint = discovery.get("authorization_endpoint") - if authorization_endpoint is not None: - validate_endpoint_uri(authorization_endpoint, https=https) - validate_endpoint_uri(token_endpoint, https=https) revocation_endpoint = discovery.get("revocation_endpoint") - if revocation_endpoint is not None: - validate_endpoint_uri(revocation_endpoint, https=https) introspection_endpoint = discovery.get("introspection_endpoint") - if introspection_endpoint is not None: - validate_endpoint_uri(introspection_endpoint, https=https) userinfo_endpoint = discovery.get("userinfo_endpoint") - if userinfo_endpoint is not None: - validate_endpoint_uri(userinfo_endpoint, https=https) jwks_uri = discovery.get("jwks_uri") if jwks_uri is not None: validate_endpoint_uri(jwks_uri, https=https) @@ -1534,9 +1559,9 @@ def from_discovery_document( ) def __enter__(self) -> OAuth2Client: - """Allow using OAuth2Client as a context-manager. + """Allow using `OAuth2Client` as a context-manager. - The Authorization Server public keys are retrieved on __enter__. + The Authorization Server public keys are retrieved on `__enter__`. """ self.update_authorization_server_public_keys() @@ -1569,3 +1594,84 @@ class GrantType(str, Enum): JWT_BEARER = "urn:ietf:params:oauth:grant-type:jwt-bearer" CLIENT_INITIATED_BACKCHANNEL_AUTHENTICATION = "urn:openid:params:grant-type:ciba" DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code" + + +class TestingOAuth2Client(OAuth2Client): + """A testing-purposes OAuth2Client, for local AS testing and debugging only. + + Compared to the OAuth2Client base class, this will: + - allow arbitrary URLs for the authorization server, including: + - non-HTTPS + - custom ports + - things not allowed by the standards: fragments, username and passwords in URI, etc. + - disable server certificate verification + + """ + + def __init__( # noqa: PLR0913 + self, + token_endpoint: str, + auth: ( + requests.auth.AuthBase | tuple[str, str] | tuple[str, Jwk] | tuple[str, dict[str, Any]] | str | None + ) = None, + *, + client_id: str | None = None, + client_secret: str | None = None, + private_key: Jwk | dict[str, Any] | None = None, + revocation_endpoint: str | None = None, + introspection_endpoint: str | None = None, + userinfo_endpoint: str | None = None, + authorization_endpoint: str | None = None, + redirect_uri: str | None = None, + backchannel_authentication_endpoint: str | None = None, + device_authorization_endpoint: str | None = None, + pushed_authorization_request_endpoint: str | None = None, + jwks_uri: str | None = None, + authorization_server_jwks: JwkSet | dict[str, Any] | None = None, + issuer: str | None = None, + id_token_signed_response_alg: str | None = SignatureAlgs.RS256, + id_token_encrypted_response_alg: str | None = None, + id_token_decryption_key: Jwk | dict[str, Any] | None = None, + code_challenge_method: str = CodeChallengeMethods.S256, + authorization_response_iss_parameter_supported: bool = False, + session: requests.Session | None = None, + **extra_metadata: Any, + ): + if session is None: + session = requests.Session() + session.verify = False + super().__init__( + token_endpoint, + auth, + client_id=client_id, + client_secret=client_secret, + private_key=private_key, + revocation_endpoint=revocation_endpoint, + introspection_endpoint=introspection_endpoint, + userinfo_endpoint=userinfo_endpoint, + authorization_endpoint=authorization_endpoint, + redirect_uri=redirect_uri, + backchannel_authentication_endpoint=backchannel_authentication_endpoint, + device_authorization_endpoint=device_authorization_endpoint, + pushed_authorization_request_endpoint=pushed_authorization_request_endpoint, + jwks_uri=jwks_uri, + authorization_server_jwks=authorization_server_jwks, + issuer=issuer, + id_token_signed_response_alg=id_token_signed_response_alg, + id_token_encrypted_response_alg=id_token_encrypted_response_alg, + id_token_decryption_key=id_token_decryption_key, + code_challenge_method=code_challenge_method, + authorization_response_iss_parameter_supported=authorization_response_iss_parameter_supported, + session=session, + **extra_metadata, + ) + + @override + def validate_endpoint_uri(self, uri: str) -> str: + """Disable endpoint URIs validation, for testing purposes.""" + return uri + + @override + def validate_issuer_uri(self, uri: str) -> str: + """Disable issuer URI validation, for testing purposes.""" + return uri diff --git a/requests_oauth2client/utils.py b/requests_oauth2client/utils.py index 57ee8f4..d151b48 100644 --- a/requests_oauth2client/utils.py +++ b/requests_oauth2client/utils.py @@ -13,12 +13,22 @@ from furl import furl # type: ignore[import-untyped] -def validate_endpoint_uri(uri: str, *, https: bool = True, no_fragment: bool = True, path: bool = True) -> None: +def validate_endpoint_uri( + uri: str, + *, + https: bool = True, + no_credentials: bool = True, + no_port: bool = True, + no_fragment: bool = True, + path: bool = True, +) -> str: """Validate that a URI is suitable as an endpoint URI. It checks: - that the scheme is `https` + - that no custom port number is being used + - that no username or password are included - that no fragment is included - that a path is present @@ -31,25 +41,45 @@ def validate_endpoint_uri(uri: str, *, https: bool = True, no_fragment: bool = T Args: uri: the uri https: if `True`, check that the uri is https + no_port: if `True`, check that no custom port number is included + no_credentials: if ` True`, check that no username/password are included no_fragment: if `True`, check that the uri contains no fragment path: if `True`, check that the uri contains a path component Raises: ValueError: if the supplied url is not suitable + Returns: + the endpoint URI, if all checks passed + """ url = furl(uri) msg: list[str] = [] if https and url.scheme != "https": - msg += "url must use https" + msg.append("url must use https") + if no_port and url.port != 443: # noqa: PLR2004 + msg.append("no custom port number allowed") + if no_credentials and url.username or url.password: + msg.append("no username or password are allowed") if no_fragment and url.fragment: - msg += "url must not contain a fragment" + msg.append("url must not contain a fragment") if path and (not url.path or url.path == "/"): - msg += "url has no path" + msg.append("url has no path") if msg: raise ValueError(", ".join(msg)) + return uri + + +def validate_issuer_uri(uri: str) -> str: + """Validate that an Issuer Identifier URI is valid. + + This is almost the same as a valid endpoint URI, but a path is not mandatory. + + """ + return validate_endpoint_uri(uri, path=False) + def accepts_expires_in(f: Callable[..., Any]) -> Callable[..., Any]: """Decorate methods to handle both `expires_at` and `expires_in`. diff --git a/requests_oauth2client/vendor_specific/auth0.py b/requests_oauth2client/vendor_specific/auth0.py index 0d8ca98..0372d84 100644 --- a/requests_oauth2client/vendor_specific/auth0.py +++ b/requests_oauth2client/vendor_specific/auth0.py @@ -55,11 +55,12 @@ def client( ) -> OAuth2Client: """Initialise an OAuth2Client for an Auth0 tenant.""" tenant = cls.tenant(tenant) - token_endpoint = f"https://{tenant}/oauth/token" - authorization_endpoint = f"https://{tenant}/authorize" - revocation_endpoint = f"https://{tenant}/oauth/revoke" - userinfo_endpoint = f"https://{tenant}/userinfo" - jwks_uri = f"https://{tenant}/.well-known/jwks.json" + issuer = f"https://{tenant}" + token_endpoint = f"{issuer}/oauth/token" + authorization_endpoint = f"{issuer}/authorize" + revocation_endpoint = f"{issuer}/oauth/revoke" + userinfo_endpoint = f"{issuer}/userinfo" + jwks_uri = f"{issuer}/.well-known/jwks.json" return OAuth2Client( auth=auth, @@ -71,7 +72,7 @@ def client( authorization_endpoint=authorization_endpoint, revocation_endpoint=revocation_endpoint, userinfo_endpoint=userinfo_endpoint, - issuer=tenant, + issuer=issuer, jwks_uri=jwks_uri, **kwargs, ) diff --git a/tests/test_oidc.py b/tests/test_oidc.py index 127fd61..e336a2c 100644 --- a/tests/test_oidc.py +++ b/tests/test_oidc.py @@ -49,8 +49,8 @@ def test_encrypted_id_token(requests_mock: RequestsMocker) -> None: "kid": "Vs6sw5LGsEYfeiAs3rwiOwXKJpw4S926IaOpefvm-Ec", } ) - token_endpoint = "https://token.endpoint" - authorization_endpoint = "https://authorization.endpoint" + token_endpoint = "https://as.local/token" + authorization_endpoint = "https://as.local/authorize" issuer = "https://issuer" claims = {"iss": issuer, "iat": Jwt.timestamp(), "exp": Jwt.timestamp(60), "sub": subject, "nonce": nonce} diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index 74fec0b..a02b3bd 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -28,6 +28,7 @@ UnauthorizedClient, UnknownIntrospectionError, oidc_discovery_document_url, + TestingOAuth2Client, ) from tests.conftest import RequestsMocker, RequestValidatorType @@ -616,6 +617,17 @@ def test_from_discovery_document( auth=client_id, ) + with pytest.warns(match="https parameter is deprecated"): + OAuth2Client.from_discovery_document({ + "issuer": issuer, + "token_endpoint": token_endpoint, + "revocation_endpoint": revocation_endpoint, + "introspection_endpoint": introspection_endpoint, + "userinfo_endpoint": userinfo_endpoint, + "jwks_uri": jwks_uri, + }, issuer=issuer, + auth=client_id,https=False) + def test_from_discovery_document_missing_token_endpoint(revocation_endpoint: str, client_id: str) -> None: """Invalid discovery documents raises an exception.""" @@ -1401,13 +1413,13 @@ def test_client_authorization_server_jwks() -> None: jwks = Jwk.generate(alg="ES256").public_jwk().as_jwks() assert ( OAuth2Client( - "https://token.endpoint", client_id="client_id", authorization_server_jwks=jwks + "https://as.local/token", client_id="client_id", authorization_server_jwks=jwks ).authorization_server_jwks is jwks ) assert ( OAuth2Client( - "https://token.endpoint", client_id="client_id", authorization_server_jwks=jwks.to_dict() + "https://as.local/token", client_id="client_id", authorization_server_jwks=jwks.to_dict() ).authorization_server_jwks == jwks ) @@ -1417,20 +1429,20 @@ def test_client_id_token_decryption_key() -> None: decryption_key = Jwk.generate(alg=KeyManagementAlgs.ECDH_ES_A256KW) assert ( OAuth2Client( - "https://token.endpoint", client_id="client_id", id_token_decryption_key=decryption_key + "https://as.local/token", client_id="client_id", id_token_decryption_key=decryption_key ).id_token_decryption_key is decryption_key ) assert ( OAuth2Client( - "https://token.endpoint", client_id="client_id", id_token_decryption_key=decryption_key.to_dict() + "https://as.local/token", client_id="client_id", id_token_decryption_key=decryption_key.to_dict() ).id_token_decryption_key == decryption_key ) with pytest.raises(ValueError, match="no decryption algorithm is defined"): assert OAuth2Client( - "https://token.endpoint", client_id="client_id", id_token_decryption_key=decryption_key.minimize() + "https://as.local/token", client_id="client_id", id_token_decryption_key=decryption_key.minimize() ) @@ -1441,4 +1453,22 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques return request with pytest.raises(AttributeError, match="custom authentication method without client_id"): - OAuth2Client("https://token.endpoint", auth=CustomAuthHandler()).client_id + OAuth2Client("https://as.local/token", auth=CustomAuthHandler()).client_id + + +def test_testing_oauth2client() -> None: + token_endpoint = "http://localhost:1234/token" + + with pytest.raises(ValueError, match="must use https"): + OAuth2Client(token_endpoint=token_endpoint, client_id="client_id") + + test_client = TestingOAuth2Client( + token_endpoint=token_endpoint, + client_id="foo", + client_secret="bar", + issuer="http://localhost:1234", + ) + + assert test_client.token_endpoint == token_endpoint + + diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py index df82acd..395a549 100644 --- a/tests/unit_tests/test_utils.py +++ b/tests/unit_tests/test_utils.py @@ -9,12 +9,16 @@ def test_validate_uri() -> None: validate_endpoint_uri("https://myas.local/token") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="https"): validate_endpoint_uri("http://myas.local/token") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="path"): validate_endpoint_uri("https://myas.local") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="fragment"): validate_endpoint_uri("https://myas.local/token#foo") + with pytest.raises(ValueError, match="username"): + validate_endpoint_uri("https://user:passwd@myas.local/token") + with pytest.raises(ValueError, match="port"): + validate_endpoint_uri("https://myas.local:1234/token") @pytest.mark.parametrize("expires_in", [10, "10"])