Skip to content

Commit

Permalink
enforce endpoint uri validation (#46)
Browse files Browse the repository at this point in the history
* enforce endpoint uri validation
* introduce TestOAuthClient for AS development and testing purposes
  • Loading branch information
guillp authored Feb 14, 2024
1 parent b7e7aba commit cf3eae8
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 47 deletions.
4 changes: 4 additions & 0 deletions requests_oauth2client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
AuthorizationRequest,
AuthorizationRequestSerializer,
AuthorizationResponse,
CodeChallengeMethods,
PkceUtils,
RequestParameterAuthorizationRequest,
RequestUriParameterAuthorizationRequest,
Expand All @@ -32,6 +33,7 @@
from .client import (
GrantType,
OAuth2Client,
TestingOAuth2Client,
)
from .client_authentication import (
BaseClientAuthenticationMethod,
Expand Down Expand Up @@ -128,6 +130,7 @@
"ClientSecretBasic",
"ClientSecretJwt",
"ClientSecretPost",
"CodeChallengeMethods",
"ConsentRequired",
"DeviceAuthorizationError",
"DeviceAuthorizationPoolingJob",
Expand Down Expand Up @@ -181,6 +184,7 @@
"SessionSelectionRequired",
"SignatureAlgs",
"SlowDown",
"TestingOAuth2Client",
"TokenEndpointError",
"TokenEndpointPoolingJob",
"UnauthorizedClient",
Expand Down
8 changes: 8 additions & 0 deletions requests_oauth2client/authorization_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import secrets
from datetime import datetime
from enum import Enum
from typing import Any, Callable, ClassVar, Iterable, Sequence

from attrs import Factory, asdict, field, fields, frozen
Expand Down Expand Up @@ -107,6 +108,13 @@ def validate_code_verifier(cls, verifier: str, challenge: str, method: str = "S2
return cls.code_verifier_re.match(verifier) is not None and cls.derive_challenge(verifier, method) == challenge


class CodeChallengeMethods(str, Enum):
"""PKCE Code Challenge Methods."""

plain = "plain"
S256 = "S256"


@frozen(init=False)
class AuthorizationResponse:
"""Represent a successful Authorization Response.
Expand Down
158 changes: 132 additions & 26 deletions requests_oauth2client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@

from __future__ import annotations

import warnings
from enum import Enum
from typing import Any, Callable, ClassVar, Iterable, TypeVar

import requests
from attrs import field, frozen
from jwskate import Jwk, JwkSet, Jwt, SignatureAlgs
from typing_extensions import override

from .auth import BearerAuth
from .authorization_request import (
AuthorizationRequest,
AuthorizationResponse,
CodeChallengeMethods,
RequestUriParameterAuthorizationRequest,
)
from .backchannel_authentication import BackChannelAuthenticationResponse
Expand Down Expand Up @@ -44,7 +47,7 @@
UnsupportedTokenType,
)
from .tokens import BearerToken, IdToken, TokenType
from .utils import validate_endpoint_uri
from .utils import validate_endpoint_uri, validate_issuer_uri

T = TypeVar("T")

Expand Down Expand Up @@ -118,7 +121,7 @@ class OAuth2Client:
"""

auth: requests.auth.AuthBase = field(converter=client_auth_factory)
token_endpoint: str
token_endpoint: str = field()
revocation_endpoint: str | None
introspection_endpoint: str | None
userinfo_endpoint: str | None
Expand Down Expand Up @@ -224,18 +227,28 @@ def __init__( # noqa: PLR0913
session = requests.Session()

self.__attrs_init__(
token_endpoint=token_endpoint,
revocation_endpoint=revocation_endpoint,
introspection_endpoint=introspection_endpoint,
userinfo_endpoint=userinfo_endpoint,
authorization_endpoint=authorization_endpoint,
token_endpoint=self.validate_endpoint_uri(token_endpoint),
revocation_endpoint=self.validate_endpoint_uri(revocation_endpoint) if revocation_endpoint else None,
introspection_endpoint=self.validate_endpoint_uri(introspection_endpoint)
if introspection_endpoint
else None,
userinfo_endpoint=self.validate_endpoint_uri(userinfo_endpoint) if userinfo_endpoint else None,
authorization_endpoint=self.validate_endpoint_uri(authorization_endpoint)
if authorization_endpoint
else None,
redirect_uri=redirect_uri,
backchannel_authentication_endpoint=backchannel_authentication_endpoint,
device_authorization_endpoint=device_authorization_endpoint,
pushed_authorization_request_endpoint=pushed_authorization_request_endpoint,
jwks_uri=jwks_uri,
backchannel_authentication_endpoint=self.validate_endpoint_uri(backchannel_authentication_endpoint)
if backchannel_authentication_endpoint
else None,
device_authorization_endpoint=self.validate_endpoint_uri(device_authorization_endpoint)
if device_authorization_endpoint
else None,
pushed_authorization_request_endpoint=self.validate_endpoint_uri(pushed_authorization_request_endpoint)
if pushed_authorization_request_endpoint
else None,
jwks_uri=self.validate_endpoint_uri(jwks_uri) if jwks_uri else None,
authorization_server_jwks=authorization_server_jwks,
issuer=issuer,
issuer=self.validate_issuer_uri(issuer) if issuer else None,
session=session,
auth=auth,
id_token_signed_response_alg=id_token_signed_response_alg,
Expand All @@ -247,6 +260,23 @@ def __init__( # noqa: PLR0913
extra_metadata=extra_metadata,
)

def validate_endpoint_uri(self, uri: str) -> str:
"""Validate that an endpoint URI is suitable for use.
If you need to disable some checks (for AS testing purposes only!), provide a different
method here.
"""
return validate_endpoint_uri(uri)

def validate_issuer_uri(self, uri: str) -> str:
"""Validate that an Issuer identifier is suitable for use.
This is the same check as an endpoint URI, but the path may be (and usually is) empty.
"""
return validate_issuer_uri(uri)

@property
def client_id(self) -> str:
"""Client ID."""
Expand Down Expand Up @@ -1473,13 +1503,18 @@ def from_discovery_document(
private_key: private key to sign client assertions
authorization_server_jwks: the current authorization server JWKS keys
session: a requests Session to use to retrieve the document and initialise the client with
https: if True, validates that urls in the discovery document use the https scheme
https: if `True`, validates that urls in the discovery document use the https scheme
**kwargs: additional args that will be passed to OAuth2Client
Returns:
an OAuth2Client
an `OAuth2Client`
"""
if not https:
warnings.warn(
"The https parameter is deprecated. Use the TestingOAuth2Client subclass instead.", stacklevel=1
)
cls = TestingOAuth2Client
if issuer and discovery.get("issuer") != issuer:
msg = "Mismatching issuer value in discovery document: "
raise ValueError(
Expand All @@ -1494,20 +1529,10 @@ def from_discovery_document(
if token_endpoint is None:
msg = "token_endpoint not found in that discovery document"
raise ValueError(msg)
validate_endpoint_uri(token_endpoint, https=https)
authorization_endpoint = discovery.get("authorization_endpoint")
if authorization_endpoint is not None:
validate_endpoint_uri(authorization_endpoint, https=https)
validate_endpoint_uri(token_endpoint, https=https)
revocation_endpoint = discovery.get("revocation_endpoint")
if revocation_endpoint is not None:
validate_endpoint_uri(revocation_endpoint, https=https)
introspection_endpoint = discovery.get("introspection_endpoint")
if introspection_endpoint is not None:
validate_endpoint_uri(introspection_endpoint, https=https)
userinfo_endpoint = discovery.get("userinfo_endpoint")
if userinfo_endpoint is not None:
validate_endpoint_uri(userinfo_endpoint, https=https)
jwks_uri = discovery.get("jwks_uri")
if jwks_uri is not None:
validate_endpoint_uri(jwks_uri, https=https)
Expand All @@ -1534,9 +1559,9 @@ def from_discovery_document(
)

def __enter__(self) -> OAuth2Client:
"""Allow using OAuth2Client as a context-manager.
"""Allow using `OAuth2Client` as a context-manager.
The Authorization Server public keys are retrieved on __enter__.
The Authorization Server public keys are retrieved on `__enter__`.
"""
self.update_authorization_server_public_keys()
Expand Down Expand Up @@ -1569,3 +1594,84 @@ class GrantType(str, Enum):
JWT_BEARER = "urn:ietf:params:oauth:grant-type:jwt-bearer"
CLIENT_INITIATED_BACKCHANNEL_AUTHENTICATION = "urn:openid:params:grant-type:ciba"
DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"


class TestingOAuth2Client(OAuth2Client):
"""A testing-purposes OAuth2Client, for local AS testing and debugging only.
Compared to the OAuth2Client base class, this will:
- allow arbitrary URLs for the authorization server, including:
- non-HTTPS
- custom ports
- things not allowed by the standards: fragments, username and passwords in URI, etc.
- disable server certificate verification
"""

def __init__( # noqa: PLR0913
self,
token_endpoint: str,
auth: (
requests.auth.AuthBase | tuple[str, str] | tuple[str, Jwk] | tuple[str, dict[str, Any]] | str | None
) = None,
*,
client_id: str | None = None,
client_secret: str | None = None,
private_key: Jwk | dict[str, Any] | None = None,
revocation_endpoint: str | None = None,
introspection_endpoint: str | None = None,
userinfo_endpoint: str | None = None,
authorization_endpoint: str | None = None,
redirect_uri: str | None = None,
backchannel_authentication_endpoint: str | None = None,
device_authorization_endpoint: str | None = None,
pushed_authorization_request_endpoint: str | None = None,
jwks_uri: str | None = None,
authorization_server_jwks: JwkSet | dict[str, Any] | None = None,
issuer: str | None = None,
id_token_signed_response_alg: str | None = SignatureAlgs.RS256,
id_token_encrypted_response_alg: str | None = None,
id_token_decryption_key: Jwk | dict[str, Any] | None = None,
code_challenge_method: str = CodeChallengeMethods.S256,
authorization_response_iss_parameter_supported: bool = False,
session: requests.Session | None = None,
**extra_metadata: Any,
):
if session is None:
session = requests.Session()
session.verify = False
super().__init__(
token_endpoint,
auth,
client_id=client_id,
client_secret=client_secret,
private_key=private_key,
revocation_endpoint=revocation_endpoint,
introspection_endpoint=introspection_endpoint,
userinfo_endpoint=userinfo_endpoint,
authorization_endpoint=authorization_endpoint,
redirect_uri=redirect_uri,
backchannel_authentication_endpoint=backchannel_authentication_endpoint,
device_authorization_endpoint=device_authorization_endpoint,
pushed_authorization_request_endpoint=pushed_authorization_request_endpoint,
jwks_uri=jwks_uri,
authorization_server_jwks=authorization_server_jwks,
issuer=issuer,
id_token_signed_response_alg=id_token_signed_response_alg,
id_token_encrypted_response_alg=id_token_encrypted_response_alg,
id_token_decryption_key=id_token_decryption_key,
code_challenge_method=code_challenge_method,
authorization_response_iss_parameter_supported=authorization_response_iss_parameter_supported,
session=session,
**extra_metadata,
)

@override
def validate_endpoint_uri(self, uri: str) -> str:
"""Disable endpoint URIs validation, for testing purposes."""
return uri

@override
def validate_issuer_uri(self, uri: str) -> str:
"""Disable issuer URI validation, for testing purposes."""
return uri
38 changes: 34 additions & 4 deletions requests_oauth2client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,22 @@
from furl import furl # type: ignore[import-untyped]


def validate_endpoint_uri(uri: str, *, https: bool = True, no_fragment: bool = True, path: bool = True) -> None:
def validate_endpoint_uri(
uri: str,
*,
https: bool = True,
no_credentials: bool = True,
no_port: bool = True,
no_fragment: bool = True,
path: bool = True,
) -> str:
"""Validate that a URI is suitable as an endpoint URI.
It checks:
- that the scheme is `https`
- that no custom port number is being used
- that no username or password are included
- that no fragment is included
- that a path is present
Expand All @@ -31,25 +41,45 @@ def validate_endpoint_uri(uri: str, *, https: bool = True, no_fragment: bool = T
Args:
uri: the uri
https: if `True`, check that the uri is https
no_port: if `True`, check that no custom port number is included
no_credentials: if ` True`, check that no username/password are included
no_fragment: if `True`, check that the uri contains no fragment
path: if `True`, check that the uri contains a path component
Raises:
ValueError: if the supplied url is not suitable
Returns:
the endpoint URI, if all checks passed
"""
url = furl(uri)
msg: list[str] = []
if https and url.scheme != "https":
msg += "url must use https"
msg.append("url must use https")
if no_port and url.port != 443: # noqa: PLR2004
msg.append("no custom port number allowed")
if no_credentials and url.username or url.password:
msg.append("no username or password are allowed")
if no_fragment and url.fragment:
msg += "url must not contain a fragment"
msg.append("url must not contain a fragment")
if path and (not url.path or url.path == "/"):
msg += "url has no path"
msg.append("url has no path")

if msg:
raise ValueError(", ".join(msg))

return uri


def validate_issuer_uri(uri: str) -> str:
"""Validate that an Issuer Identifier URI is valid.
This is almost the same as a valid endpoint URI, but a path is not mandatory.
"""
return validate_endpoint_uri(uri, path=False)


def accepts_expires_in(f: Callable[..., Any]) -> Callable[..., Any]:
"""Decorate methods to handle both `expires_at` and `expires_in`.
Expand Down
13 changes: 7 additions & 6 deletions requests_oauth2client/vendor_specific/auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ def client(
) -> OAuth2Client:
"""Initialise an OAuth2Client for an Auth0 tenant."""
tenant = cls.tenant(tenant)
token_endpoint = f"https://{tenant}/oauth/token"
authorization_endpoint = f"https://{tenant}/authorize"
revocation_endpoint = f"https://{tenant}/oauth/revoke"
userinfo_endpoint = f"https://{tenant}/userinfo"
jwks_uri = f"https://{tenant}/.well-known/jwks.json"
issuer = f"https://{tenant}"
token_endpoint = f"{issuer}/oauth/token"
authorization_endpoint = f"{issuer}/authorize"
revocation_endpoint = f"{issuer}/oauth/revoke"
userinfo_endpoint = f"{issuer}/userinfo"
jwks_uri = f"{issuer}/.well-known/jwks.json"

return OAuth2Client(
auth=auth,
Expand All @@ -71,7 +72,7 @@ def client(
authorization_endpoint=authorization_endpoint,
revocation_endpoint=revocation_endpoint,
userinfo_endpoint=userinfo_endpoint,
issuer=tenant,
issuer=issuer,
jwks_uri=jwks_uri,
**kwargs,
)
Expand Down
Loading

0 comments on commit cf3eae8

Please sign in to comment.