From 3c9ec5e8d757ef4d58b9c6dae9fcd6bd9994c09f Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sun, 14 Jul 2024 17:43:33 +1000 Subject: [PATCH 01/33] feat(jwt): support jwe This adds support for RFC7516 JSON Web Encryption (JWE). This adds a layer of privacy on top of tokens generated and asserted making their contents irrelevant. This is supported for both generated tokens and decoded tokens for elements like client assertions. This implementation uses the nested format which is generally the supported and recommended format. --- access_error_test.go | 4 +- access_write_test.go | 2 +- authorize_request_handler.go | 73 +-- ...orize_request_handler_oidc_request_test.go | 22 +- client.go | 64 ++- client_authentication_jwks_strategy.go | 16 +- client_authentication_test.go | 8 +- compose/compose.go | 14 +- compose/compose_oauth2.go | 4 +- compose/compose_openid.go | 8 +- compose/compose_strategy.go | 14 +- config.go | 23 +- config_default.go | 33 +- fosite.go | 5 +- handler/oauth2/introspector_jwt.go | 36 +- handler/oauth2/introspector_jwt_test.go | 23 +- handler/oauth2/strategy.go | 6 +- handler/oauth2/strategy_jwt_profile.go | 11 +- handler/oauth2/strategy_jwt_profile_test.go | 15 +- handler/oauth2/strategy_jwt_session.go | 7 +- .../openid/flow_device_authorization_test.go | 32 +- handler/openid/flow_explicit_auth_test.go | 15 +- handler/openid/flow_explicit_token_test.go | 19 +- handler/openid/flow_hybrid.go | 2 +- handler/openid/flow_hybrid_test.go | 40 +- handler/openid/flow_implicit.go | 2 +- handler/openid/flow_implicit_test.go | 17 +- handler/openid/flow_refresh_token_test.go | 11 +- handler/openid/helper_test.go | 8 +- handler/openid/strategy_jwt.go | 9 +- handler/openid/strategy_jwt_test.go | 14 +- handler/openid/validator.go | 6 +- handler/openid/validator_test.go | 11 +- handler/rfc7523/handler.go | 4 +- handler/rfc8693/custom_jwt_type_handler.go | 10 +- handler/rfc8693/id_token_type_handler.go | 4 +- handler/rfc8693/token_exchange_test.go | 25 +- integration/helper_setup_test.go | 8 +- integration/introspect_token_test.go | 11 +- internal/consts/client_auth_method.go | 2 +- internal/consts/const.go | 1 + internal/consts/jwt.go | 22 +- internal/consts/spec.go | 2 +- introspection_response_writer.go | 4 +- response_handler.go | 2 +- testing/mock/client.go | 15 + token/jarm/generate.go | 6 +- token/jarm/types.go | 4 +- token/jwt/client.go | 468 +++++++++++++++++ token/jwt/client_test.go | 56 ++ token/jwt/consts.go | 34 ++ token/jwt/errors.go | 7 + token/jwt/issuer.go | 115 ++++ token/jwt/jwt.go | 198 ------- token/jwt/jwt_signer_test.go | 320 +++++++++++ token/jwt/jwt_strategy.go | 284 ++++++++++ token/jwt/jwt_strategy_opts.go | 179 +++++++ token/jwt/jwt_strategy_test.go | 361 +++++++++++++ token/jwt/jwt_test.go | 214 -------- token/jwt/token.go | 497 ++++++++++++------ token/jwt/token_test.go | 14 +- token/jwt/util.go | 310 +++++++++++ 62 files changed, 2846 insertions(+), 905 deletions(-) create mode 100644 token/jwt/client.go create mode 100644 token/jwt/client_test.go create mode 100644 token/jwt/consts.go create mode 100644 token/jwt/errors.go create mode 100644 token/jwt/issuer.go delete mode 100644 token/jwt/jwt.go create mode 100644 token/jwt/jwt_signer_test.go create mode 100644 token/jwt/jwt_strategy.go create mode 100644 token/jwt/jwt_strategy_opts.go create mode 100644 token/jwt/jwt_strategy_test.go delete mode 100644 token/jwt/jwt_test.go create mode 100644 token/jwt/util.go diff --git a/access_error_test.go b/access_error_test.go index d43f9f85..1485e131 100644 --- a/access_error_test.go +++ b/access_error_test.go @@ -30,7 +30,7 @@ func TestWriteAccessError(t *testing.T) { rw.EXPECT().WriteHeader(http.StatusBadRequest) rw.EXPECT().Write(gomock.Any()) - provider.WriteAccessError(context.Background(), rw, nil, ErrInvalidRequest) + provider.WriteAccessError(context.TODO(), rw, nil, ErrInvalidRequest) } func TestWriteAccessError_RFC6749(t *testing.T) { @@ -62,7 +62,7 @@ func TestWriteAccessError_RFC6749(t *testing.T) { config.UseLegacyErrorFormat = c.includeExtraFields rw := httptest.NewRecorder() - provider.WriteAccessError(context.Background(), rw, nil, c.err) + provider.WriteAccessError(context.TODO(), rw, nil, c.err) var params struct { Error string `json:"error"` // specified by RFC, required diff --git a/access_write_test.go b/access_write_test.go index 9b7124d2..6de92459 100644 --- a/access_write_test.go +++ b/access_write_test.go @@ -30,7 +30,7 @@ func TestWriteAccessResponse(t *testing.T) { rw.EXPECT().Write(gomock.Any()) resp.EXPECT().ToMap().Return(map[string]any{}) - provider.WriteAccessResponse(context.Background(), rw, ar, resp) + provider.WriteAccessResponse(context.TODO(), rw, ar, resp) assert.Equal(t, consts.ContentTypeApplicationJSON, header.Get(consts.HeaderContentType)) assert.Equal(t, consts.CacheControlNoStore, header.Get(consts.HeaderCacheControl)) assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 6bd8b451..46573d2a 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -10,7 +10,6 @@ import ( "net/http" "strings" - "github.com/go-jose/go-jose/v4" "github.com/pkg/errors" "authelia.com/provider/oauth2/i18n" @@ -74,10 +73,11 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co } var ( + alg string algAny, algNone bool ) - switch alg := client.GetRequestObjectSigningAlg(); alg { + switch alg = client.GetRequestObjectSigningAlg(); alg { case consts.JSONWebTokenAlgNone: algNone = true case "": @@ -123,65 +123,34 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co assertion = request.Form.Get(consts.FormParameterRequest) } - token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (key any, err error) { - // request_object_signing_alg - OPTIONAL. - // JWS [JWS] alg algorithm [JWA] that MUST be used for signing Request Objects sent to the OP. All Request Objects from this Client MUST be rejected, - // if not signed with this algorithm. Request Objects are described in Section 6.1 of OpenID Connect Core 1.0 [OpenID.Core]. This algorithm MUST - // be used both when the Request Object is passed by value (using the request parameter) and when it is passed by reference (using the request_uri parameter). - // Servers SHOULD support RS256. The value none MAY be used. The default, if omitted, is that any algorithm supported by the OP and the RP MAY be used. - if !algAny && client.GetRequestObjectSigningAlg() != fmt.Sprintf("%s", t.Header[consts.JSONWebTokenHeaderAlgorithm]) { - return nil, errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' expects request objects to be signed with the '%s' algorithm but the request object was signed with the '%s' algorithm.", request.GetClient().GetID(), client.GetRequestObjectSigningAlg(), t.Header[consts.JSONWebTokenHeaderAlgorithm])) - } - - if t.Method == jwt.SigningMethodNone { - algNone = true + strategy := f.Config.GetJWTStrategy(ctx) - return jwt.UnsafeAllowNoneSignatureType, nil - } else if algNone { - return nil, errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' expects request objects to be signed with the '%s' algorithm but the request object was signed with the '%s' algorithm.", request.GetClient().GetID(), client.GetRequestObjectSigningAlg(), t.Header[consts.JSONWebTokenHeaderAlgorithm])) + token, err := strategy.Decode(ctx, assertion, jwt.WithSigAlgorithm(jwt.SignatureAlgorithmsNone...), jwt.WithJARClient(client)) + if err != nil { + var e *jwt.ValidationError + if errors.As(err, &e) { + return wrapSigningKeyFailure(ErrInvalidRequestObject.WithHintf("The OAuth 2.0 client with id '%s' could not validate the request object.", client.GetID()), err) + } else { + return errorsx.WithStack(ErrInvalidRequestObject.WithHintf("The OAuth 2.0 client with id '%s' could not validate the request object.", client.GetID()).WithDebugError(err)) } + } - switch t.Method { - case jose.RS256, jose.RS384, jose.RS512: - if key, err = f.findClientPublicJWK(ctx, client, t, true); err != nil { - return nil, wrapSigningKeyFailure( - ErrInvalidRequestObject.WithHint("Unable to retrieve RSA signing key from OAuth 2.0 Client."), err) - } - - return key, nil - case jose.ES256, jose.ES384, jose.ES512: - if key, err = f.findClientPublicJWK(ctx, client, t, false); err != nil { - return nil, wrapSigningKeyFailure( - ErrInvalidRequestObject.WithHint("Unable to retrieve ECDSA signing key from OAuth 2.0 Client."), err) - } - - return key, nil - case jose.PS256, jose.PS384, jose.PS512: - if key, err = f.findClientPublicJWK(ctx, client, t, true); err != nil { - return nil, wrapSigningKeyFailure( - ErrInvalidRequestObject.WithHint("Unable to retrieve RSA signing key from OAuth 2.0 Client."), err) - } + if algAny { + if token.SignatureAlgorithm == "none" { - return key, nil - default: - return nil, errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object that uses the unsupported signing algorithm '%s'.", request.GetClient().GetID(), t.Header[consts.JSONWebTokenHeaderAlgorithm])) } - }) - if err != nil { - // Do not re-process already enhanced errors - var e *jwt.ValidationError - if errors.As(err, &e) { - if e.Inner != nil { - return e.Inner - } + if kid := client.GetRequestObjectSigningKeyID(); kid != "" && kid != token.KeyID { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object which failed to validate with error: %+v.", request.GetClient().GetID(), err).WithWrap(err)) } + } else if string(token.SignatureAlgorithm) != alg { - return err - } else if err = token.Claims.Valid(); err != nil { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object which could not be validated because its claims could not be validated with error: %+v.", request.GetClient().GetID(), err).WithWrap(err)) + } + + if token.SignatureAlgorithm == "none" && !algNone { + print("") + } else if !algAny && string(token.SignatureAlgorithm) != client.GetRequestObjectSigningAlg() { + print("") } claims := token.Claims diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index 7d957d1b..ac618198 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -30,9 +30,10 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) jwks := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - KeyID: "kid-foo", - Use: "sig", - Key: &key.PublicKey, + KeyID: "kid-foo", + Use: "sig", + Algorithm: string(jose.RS256), + Key: &key.PublicKey, }, }, } @@ -371,7 +372,14 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) }, } - provider := &Fosite{Config: &Config{JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), IDTokenIssuer: "https://auth.example.com"}} + config := &Config{JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), IDTokenIssuer: "https://auth.example.com"} + + strategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.MustGenDefaultIssuer(), + } + + provider := &Fosite{Config: &Config{JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), IDTokenIssuer: "https://auth.example.com", JWTStrategy: strategy}} err = provider.authorizeRequestParametersFromOpenIDConnectRequestObject(context.Background(), r, tc.par) if tc.err != nil { @@ -400,21 +408,21 @@ func mustGenerateAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateK if kid != "" { token.Header[consts.JSONWebTokenHeaderKeyIdentifier] = kid } - tokenString, err := token.SignedString(key) + tokenString, err := token.CompactSignedString(key) require.NoError(t, err) return tokenString } func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims) string { token := jwt.NewWithClaims(jose.HS256, claims) - tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")) + tokenString, err := token.CompactSignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")) require.NoError(t, err) return tokenString } func mustGenerateNoneAssertion(t *testing.T, claims jwt.MapClaims) string { token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) - tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) + tokenString, err := token.CompactSignedString(jwt.UnsafeAllowNoneSignatureType) require.NoError(t, err) return tokenString } diff --git a/client.go b/client.go index 2dc370f2..34bd7d69 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ package oauth2 import ( "context" + "fmt" "time" "github.com/go-jose/go-jose/v4" @@ -17,8 +18,12 @@ type Client interface { // GetID returns the client ID. GetID() (id string) + // GetClientSecret returns the ClientSecret. GetClientSecret() (secret ClientSecret) + // GetClientSecretPlainText returns the ClientSecret as plaintext if available. + GetClientSecretPlainText() (secret []byte, err error) + // GetRedirectURIs returns the client's allowed redirect URIs. GetRedirectURIs() []string @@ -230,6 +235,19 @@ type AuthenticationMethodClient interface { // methods. GetRevocationEndpointAuthSigningAlg() (alg string) + // GetPushedAuthorizationRequestEndpointAuthMethod is equivalent to the + // 'pushed_authorize_request_endpoint_auth_method' client metadata value which determines the requested Client + // Authentication method for the Pushed Authorization Request Endpoint. The options are client_secret_post, + // client_secret_basic, client_secret_jwt, and private_key_jwt. + GetPushedAuthorizationRequestEndpointAuthMethod() (method string) + + // GetPushedAuthorizationRequestEndpointAuthSigningAlg is equivalent to the + // 'pushed_authorization_request_endpoint_auth_signing_alg' client metadata value which determines the JWS [JWS] alg + // algorithm [JWA] that MUST be used for signing the JWT [JWT] used to authenticate the + // Client at the Pushed Authorization Request Endpoint for the private_key_jwt and client_secret_jwt authentication + // methods. + GetPushedAuthorizationRequestEndpointAuthSigningAlg() (alg string) + JSONWebKeysClient } @@ -414,20 +432,22 @@ type DefaultClient struct { type DefaultJARClient struct { *DefaultClient - JSONWebKeysURI string `json:"jwks_uri"` - JSONWebKeys *jose.JSONWebKeySet `json:"jwks"` - TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` - IntrospectionEndpointAuthMethod string `json:"introspection_endpoint_auth_method"` - RevocationEndpointAuthMethod string `json:"revocation_endpoint_auth_method"` - RequestURIs []string `json:"request_uris"` - RequestObjectSigningKeyID string `json:"request_object_signing_kid"` - RequestObjectSigningAlg string `json:"request_object_signing_alg"` - RequestObjectEncryptionKeyID string `json:"request_object_encryption_kid"` - RequestObjectEncryptionAlg string `json:"request_object_encryption_alg"` - RequestObjectEncryptionEnc string `json:"request_object_encryption_enc"` - TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` - IntrospectionEndpointAuthSigningAlg string `json:"introspection_endpoint_auth_signing_alg"` - RevocationEndpointAuthSigningAlg string `json:"revocation_endpoint_auth_signing_alg"` + JSONWebKeysURI string `json:"jwks_uri"` + JSONWebKeys *jose.JSONWebKeySet `json:"jwks"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + IntrospectionEndpointAuthMethod string `json:"introspection_endpoint_auth_method"` + RevocationEndpointAuthMethod string `json:"revocation_endpoint_auth_method"` + PushedAuthorizationRequestEndpointAuthMethod string `json:"pushed_authorization_request_endpoint_auth_method"` + RequestURIs []string `json:"request_uris"` + RequestObjectSigningKeyID string `json:"request_object_signing_kid"` + RequestObjectSigningAlg string `json:"request_object_signing_alg"` + RequestObjectEncryptionKeyID string `json:"request_object_encryption_kid"` + RequestObjectEncryptionAlg string `json:"request_object_encryption_alg"` + RequestObjectEncryptionEnc string `json:"request_object_encryption_enc"` + TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` + IntrospectionEndpointAuthSigningAlg string `json:"introspection_endpoint_auth_signing_alg"` + RevocationEndpointAuthSigningAlg string `json:"revocation_endpoint_auth_signing_alg"` + PushedAuthorizationRequestEndpointAuthSigningAlg string `json:"pushed_authorization_request_endpoint_auth_signing_alg"` } type DefaultResponseModeClient struct { @@ -455,6 +475,14 @@ func (c *DefaultClient) GetClientSecret() (secret ClientSecret) { return c.ClientSecret } +func (c *DefaultClient) GetClientSecretPlainText() (secret []byte, err error) { + if c.ClientSecret == nil || !c.ClientSecret.Valid() { + return nil, fmt.Errorf("this secret doesn't support plaintext") + } + + return c.ClientSecret.GetPlainTextValue() +} + func (c *DefaultClient) GetRotatedClientSecrets() (secrets []ClientSecret) { return c.RotatedClientSecrets } @@ -513,6 +541,10 @@ func (c *DefaultJARClient) GetRevocationEndpointAuthSigningAlg() string { return c.RevocationEndpointAuthSigningAlg } +func (c *DefaultJARClient) GetPushedAuthorizationRequestEndpointAuthSigningAlg() (alg string) { + return c.PushedAuthorizationRequestEndpointAuthSigningAlg +} + func (c *DefaultJARClient) GetRequestObjectSigningKeyID() string { return c.RequestObjectSigningKeyID } @@ -545,6 +577,10 @@ func (c *DefaultJARClient) GetRevocationEndpointAuthMethod() string { return c.RevocationEndpointAuthMethod } +func (c *DefaultJARClient) GetPushedAuthorizationRequestEndpointAuthMethod() string { + return c.PushedAuthorizationRequestEndpointAuthMethod +} + func (c *DefaultJARClient) GetRequestURIs() []string { return c.RequestURIs } diff --git a/client_authentication_jwks_strategy.go b/client_authentication_jwks_strategy.go index 2d12fedc..fd882a0e 100644 --- a/client_authentication_jwks_strategy.go +++ b/client_authentication_jwks_strategy.go @@ -13,21 +13,13 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/hashicorp/go-retryablehttp" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) const defaultJWKSFetcherStrategyCachePrefix = "authelia.com/provider/oauth2.DefaultJWKSFetcherStrategy:" -// JWKSFetcherStrategy is a strategy which pulls (optionally caches) JSON Web Key Sets from a location, -// typically a client's jwks_uri. -type JWKSFetcherStrategy interface { - // Resolve returns the JSON Web Key Set, or an error if something went wrong. The forceRefresh, if true, forces - // the strategy to fetch the key from the remote. If forceRefresh is false, the strategy may use a caching strategy - // to fetch the key. - Resolve(ctx context.Context, location string, ignoreCache bool) (*jose.JSONWebKeySet, error) -} - -// DefaultJWKSFetcherStrategy is a default implementation of the JWKSFetcherStrategy interface. +// DefaultJWKSFetcherStrategy is a default implementation of the jwt.JWKSFetcherStrategy interface. type DefaultJWKSFetcherStrategy struct { client *retryablehttp.Client cache *ristretto.Cache @@ -36,7 +28,7 @@ type DefaultJWKSFetcherStrategy struct { } // NewDefaultJWKSFetcherStrategy returns a new instance of the DefaultJWKSFetcherStrategy. -func NewDefaultJWKSFetcherStrategy(opts ...func(*DefaultJWKSFetcherStrategy)) JWKSFetcherStrategy { +func NewDefaultJWKSFetcherStrategy(opts ...func(*DefaultJWKSFetcherStrategy)) jwt.JWKSFetcherStrategy { dc, err := ristretto.NewCache(&ristretto.Config{ NumCounters: 10000 * 10, MaxCost: 10000, @@ -100,7 +92,7 @@ func (s *DefaultJWKSFetcherStrategy) Resolve(ctx context.Context, location strin if !ok || ignoreCache { req, err := retryablehttp.NewRequest(http.MethodGet, location, nil) if err != nil { - return nil, errorsx.WithStack(ErrServerError.WithHintf("Unable to create HTTP 'GET' request to fetch JSON Web Keys from location '%s'.", location).WithWrap(err).WithDebugError(err)) + return nil, errorsx.WithStack(ErrServerError.WithHintf("Unable to create HTTP 'GET' request to fetch JSON Web Keys from location '%s'.", location).WithWrap(err).WithDebugError(err)) } hc := s.client diff --git a/client_authentication_test.go b/client_authentication_test.go index 6ce06cfd..3995341b 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -792,7 +792,7 @@ func TestAuthenticateClientTwice(t *testing.T) { func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jose.RS256, claims) token.Header["kid"] = kid - tokenString, err := token.SignedString(key) + tokenString, err := token.CompactSignedString(key) require.NoError(t, err) return tokenString } @@ -800,7 +800,7 @@ func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.Priva func mustGenerateECDSAAssertion(t *testing.T, claims jwt.MapClaims, key *ecdsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jose.ES256, claims) token.Header["kid"] = kid - tokenString, err := token.SignedString(key) + tokenString, err := token.CompactSignedString(key) require.NoError(t, err) return tokenString } @@ -808,7 +808,7 @@ func mustGenerateECDSAAssertion(t *testing.T, claims jwt.MapClaims, key *ecdsa.P //nolint:unparam func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jose.HS256, claims) - tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")) + tokenString, err := token.CompactSignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")) require.NoError(t, err) return tokenString } @@ -816,7 +816,7 @@ func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.Privat //nolint:unparam func mustGenerateNoneAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) - tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) + tokenString, err := token.CompactSignedString(jwt.UnsafeAllowNoneSignatureType) require.NoError(t, err) return tokenString } diff --git a/compose/compose.go b/compose/compose.go index 0b002776..7596445b 100644 --- a/compose/compose.go +++ b/compose/compose.go @@ -4,10 +4,9 @@ package compose import ( - "context" - "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/token/jwt" + "context" ) type Factory func(config oauth2.Configurator, storage any, strategy any) any @@ -69,13 +68,20 @@ func ComposeAllEnabled(config *oauth2.Config, storage any, key any) oauth2.Provi keyGetter := func(context.Context) (any, error) { return key, nil } + + strategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + return Compose( config, storage, &CommonStrategy{ CoreStrategy: NewOAuth2HMACStrategy(config), - OpenIDConnectTokenStrategy: NewOpenIDConnectStrategy(keyGetter, config), - Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter}, + OpenIDConnectTokenStrategy: NewOpenIDConnectStrategy(keyGetter, strategy, config), + Strategy: strategy, + //Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter}, }, OAuth2AuthorizeExplicitFactory, OAuth2AuthorizeImplicitFactory, diff --git a/compose/compose_oauth2.go b/compose/compose_oauth2.go index 44492c32..33cb1f28 100644 --- a/compose/compose_oauth2.go +++ b/compose/compose_oauth2.go @@ -111,7 +111,7 @@ func OAuth2TokenIntrospectionFactory(config oauth2.Configurator, storage any, st // If you need revocation, you can validate JWTs statefully, using the other factories. func OAuth2StatelessJWTIntrospectionFactory(config oauth2.Configurator, storage any, strategy any) any { return &hoauth2.StatelessJWTValidator{ - Signer: strategy.(jwt.Signer), - Config: config, + Strategy: strategy.(jwt.Strategy), + Config: config, } } diff --git a/compose/compose_openid.go b/compose/compose_openid.go index 4c4676dc..abe3b2e0 100644 --- a/compose/compose_openid.go +++ b/compose/compose_openid.go @@ -20,7 +20,7 @@ func OpenIDConnectExplicitFactory(config oauth2.Configurator, storage any, strat IDTokenHandleHelper: &openid.IDTokenHandleHelper{ IDTokenStrategy: strategy.(openid.OpenIDConnectTokenStrategy), }, - OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Signer), config), + OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Strategy), config), Config: config, } } @@ -51,7 +51,7 @@ func OpenIDConnectImplicitFactory(config oauth2.Configurator, storage any, strat IDTokenHandleHelper: &openid.IDTokenHandleHelper{ IDTokenStrategy: strategy.(openid.OpenIDConnectTokenStrategy), }, - OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Signer), config), + OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Strategy), config), } } @@ -77,14 +77,14 @@ func OpenIDConnectHybridFactory(config oauth2.Configurator, storage any, strateg IDTokenStrategy: strategy.(openid.OpenIDConnectTokenStrategy), }, OpenIDConnectRequestStorage: storage.(openid.OpenIDConnectRequestStorage), - OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Signer), config), + OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Strategy), config), } } func OpenIDConnectDeviceAuthorizeFactory(config oauth2.Configurator, storage any, strategy any) any { return &openid.OpenIDConnectDeviceAuthorizeHandler{ OpenIDConnectRequestStorage: storage.(openid.OpenIDConnectRequestStorage), - OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Signer), config), + OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Strategy), config), CodeTokenEndpointHandler: &rfc8628.DeviceCodeTokenHandler{ Strategy: strategy.(rfc8628.CodeStrategy), Storage: storage.(rfc8628.Storage), diff --git a/compose/compose_strategy.go b/compose/compose_strategy.go index e03dda1c..cc272ff4 100644 --- a/compose/compose_strategy.go +++ b/compose/compose_strategy.go @@ -16,7 +16,7 @@ import ( type CommonStrategy struct { hoauth2.CoreStrategy openid.OpenIDConnectTokenStrategy - jwt.Signer + jwt.Strategy } type HMACSHAStrategyConfigurator interface { @@ -37,17 +37,17 @@ func NewOAuth2HMACStrategy(config HMACSHAStrategyConfigurator) *hoauth2.HMACCore } } -func NewOAuth2JWTStrategy(keyGetter func(context.Context) (any, error), strategy *hoauth2.HMACCoreStrategy, config oauth2.Configurator) *hoauth2.JWTProfileCoreStrategy { +func NewOAuth2JWTStrategy(strategy jwt.Strategy, strategyHMAC *hoauth2.HMACCoreStrategy, config oauth2.Configurator) *hoauth2.JWTProfileCoreStrategy { return &hoauth2.JWTProfileCoreStrategy{ - Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter}, - HMACCoreStrategy: strategy, + Strategy: strategy, + HMACCoreStrategy: strategyHMAC, Config: config, } } -func NewOpenIDConnectStrategy(keyGetter func(context.Context) (any, error), config oauth2.Configurator) *openid.DefaultStrategy { +func NewOpenIDConnectStrategy(keyGetter func(context.Context) (any, error), strategy jwt.Strategy, config oauth2.Configurator) *openid.DefaultStrategy { return &openid.DefaultStrategy{ - Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter}, - Config: config, + Strategy: strategy, + Config: config, } } diff --git a/config.go b/config.go index 7c16abd8..49cdd182 100644 --- a/config.go +++ b/config.go @@ -101,10 +101,10 @@ type IntrospectionIssuerProvider interface { GetIntrospectionIssuer(ctx context.Context) (issuer string) } -// IntrospectionJWTResponseSignerProvider returns the provider for configuring the Introspection signer. -type IntrospectionJWTResponseSignerProvider interface { - // GetIntrospectionJWTResponseSigner returns the Introspection JWT signer. - GetIntrospectionJWTResponseSigner(ctx context.Context) jwt.Signer +// IntrospectionJWTResponseStrategyProvider returns the provider for configuring the Introspection jwt.Strategy. +type IntrospectionJWTResponseStrategyProvider interface { + // GetIntrospectionJWTResponseStrategy returns the Introspection JWT Strategy. + GetIntrospectionJWTResponseStrategy(ctx context.Context) jwt.Strategy } // AuthorizationServerIssuerIdentificationProvider provides OAuth 2.0 Authorization Server Issuer Identification related methods. @@ -124,10 +124,15 @@ type JWTSecuredAuthorizeResponseModeIssuerProvider interface { GetJWTSecuredAuthorizeResponseModeIssuer(ctx context.Context) string } -// JWTSecuredAuthorizeResponseModeSignerProvider returns the provider for configuring the JARM signer. -type JWTSecuredAuthorizeResponseModeSignerProvider interface { - // GetJWTSecuredAuthorizeResponseModeSigner returns the JARM signer. - GetJWTSecuredAuthorizeResponseModeSigner(ctx context.Context) jwt.Signer +// JWTSecuredAuthorizeResponseModeStrategyProvider returns the provider for configuring the JARM jwt.Strategy. +type JWTSecuredAuthorizeResponseModeStrategyProvider interface { + // GetJWTSecuredAuthorizeResponseModeStrategy returns the JARM Strategy. + GetJWTSecuredAuthorizeResponseModeStrategy(ctx context.Context) jwt.Strategy +} + +// JWTStrategyProvider returns the provider for configuring the jwt.Strategy. +type JWTStrategyProvider interface { + GetJWTStrategy(ctx context.Context) jwt.Strategy } // JWTSecuredAuthorizeResponseModeLifespanProvider returns the provider for configuring the JWT Secured Authorize Response Mode token lifespan. @@ -253,7 +258,7 @@ type RevokeRefreshTokensExplicitlyProvider interface { // JWKSFetcherStrategyProvider returns the provider for configuring the JWKS fetcher strategy. type JWKSFetcherStrategyProvider interface { // GetJWKSFetcherStrategy returns the JWKS fetcher strategy. - GetJWKSFetcherStrategy(ctx context.Context) (strategy JWKSFetcherStrategy) + GetJWKSFetcherStrategy(ctx context.Context) (strategy jwt.JWKSFetcherStrategy) } // HTTPClientProvider returns the provider for configuring the HTTP client. diff --git a/config_default.go b/config_default.go index f1a105f2..7c7677e5 100644 --- a/config_default.go +++ b/config_default.go @@ -54,8 +54,8 @@ type Config struct { // IntrospectionIssuer is the issuer to be used when generating signed introspection responses. IntrospectionIssuer string - // IntrospectionJWTResponseSigner is the signer for Introspection Responses. Has no default. - IntrospectionJWTResponseSigner jwt.Signer + // IntrospectionJWTResponseStrategy is the signer for Introspection Responses. Has no default. + IntrospectionJWTResponseStrategy jwt.Strategy // HashCost sets the cost of the password hashing cost. Defaults to 12. HashCost int @@ -102,7 +102,7 @@ type Config struct { // JWKSFetcherStrategy is responsible for fetching JSON Web Keys from remote URLs. This is required when the private_key_jwt // client authentication method is used. Defaults to oauth2.DefaultJWKSFetcherStrategy. - JWKSFetcherStrategy JWKSFetcherStrategy + JWKSFetcherStrategy jwt.JWKSFetcherStrategy // TokenEntropy indicates the entropy of the random string, used as the "message" part of the HMAC token. // Defaults to 32. @@ -175,8 +175,11 @@ type Config struct { // JWT Secured Authorization Response Mode. Defaults to 10 minutes. JWTSecuredAuthorizeResponseModeLifespan time.Duration - // JWTSecuredAuthorizeResponseModeSigner is the signer for JWT Secured Authorization Response Mode. Has no default. - JWTSecuredAuthorizeResponseModeSigner jwt.Signer + // JWTSecuredAuthorizeResponseModeStrategy is the signer for JWT Secured Authorization Response Mode. Has no default. + JWTSecuredAuthorizeResponseModeStrategy jwt.Strategy + + // JWTStrategy handles less specific jwt.Strategy cases. + JWTStrategy jwt.Strategy // EnforceJWTProfileAccessTokens forces the issuer to return JWT Profile Access Tokens to all clients. EnforceJWTProfileAccessTokens bool @@ -333,8 +336,8 @@ func (c *Config) GetIntrospectionIssuer(ctx context.Context) string { return c.IntrospectionIssuer } -func (c *Config) GetIntrospectionJWTResponseSigner(ctx context.Context) jwt.Signer { - return c.IntrospectionJWTResponseSigner +func (c *Config) GetIntrospectionJWTResponseStrategy(ctx context.Context) jwt.Strategy { + return c.IntrospectionJWTResponseStrategy } // GetGrantTypeJWTBearerIssuedDateOptional returns the GrantTypeJWTBearerIssuedDateOptional field. @@ -389,8 +392,12 @@ func (c *Config) GetJWTSecuredAuthorizeResponseModeIssuer(ctx context.Context) s return c.IDTokenIssuer } -func (c *Config) GetJWTSecuredAuthorizeResponseModeSigner(ctx context.Context) jwt.Signer { - return c.JWTSecuredAuthorizeResponseModeSigner +func (c *Config) GetJWTSecuredAuthorizeResponseModeStrategy(ctx context.Context) jwt.Strategy { + return c.JWTSecuredAuthorizeResponseModeStrategy +} + +func (c *Config) GetJWTStrategy(ctx context.Context) jwt.Strategy { + return c.JWTStrategy } func (c *Config) GetEnforceJWTProfileAccessTokens(ctx context.Context) (enable bool) { @@ -496,8 +503,8 @@ func (c *Config) GetBCryptCost(_ context.Context) int { return c.HashCost } -// GetJWKSFetcherStrategy returns the JWKSFetcherStrategy. -func (c *Config) GetJWKSFetcherStrategy(_ context.Context) JWKSFetcherStrategy { +// GetJWKSFetcherStrategy returns the jwt.JWKSFetcherStrategy. +func (c *Config) GetJWKSFetcherStrategy(_ context.Context) jwt.JWKSFetcherStrategy { if c.JWKSFetcherStrategy == nil { c.JWKSFetcherStrategy = NewDefaultJWKSFetcherStrategy() } @@ -628,7 +635,7 @@ var ( _ AccessTokenIssuerProvider = (*Config)(nil) _ JWTScopeFieldProvider = (*Config)(nil) _ JWTSecuredAuthorizeResponseModeIssuerProvider = (*Config)(nil) - _ JWTSecuredAuthorizeResponseModeSignerProvider = (*Config)(nil) + _ JWTSecuredAuthorizeResponseModeStrategyProvider = (*Config)(nil) _ JWTSecuredAuthorizeResponseModeLifespanProvider = (*Config)(nil) _ JWTProfileAccessTokensProvider = (*Config)(nil) _ AllowedPromptsProvider = (*Config)(nil) @@ -666,5 +673,5 @@ var ( _ RFC8628DeviceAuthorizeEndpointHandlersProvider = (*Config)(nil) _ RFC8628UserAuthorizeEndpointHandlersProvider = (*Config)(nil) _ IntrospectionIssuerProvider = (*Config)(nil) - _ IntrospectionJWTResponseSignerProvider = (*Config)(nil) + _ IntrospectionJWTResponseStrategyProvider = (*Config)(nil) ) diff --git a/fosite.go b/fosite.go index 8c78bbf3..e23cce4c 100644 --- a/fosite.go +++ b/fosite.go @@ -156,7 +156,7 @@ type Configurator interface { SanitationAllowedProvider JWTScopeFieldProvider JWTSecuredAuthorizeResponseModeIssuerProvider - JWTSecuredAuthorizeResponseModeSignerProvider + JWTSecuredAuthorizeResponseModeStrategyProvider JWTSecuredAuthorizeResponseModeLifespanProvider JWTProfileAccessTokensProvider AccessTokenIssuerProvider @@ -196,7 +196,8 @@ type Configurator interface { RFC8628UserAuthorizeEndpointHandlersProvider RFC9628DeviceAuthorizeConfigProvider IntrospectionIssuerProvider - IntrospectionJWTResponseSignerProvider + IntrospectionJWTResponseStrategyProvider + JWTStrategyProvider UseLegacyErrorFormatProvider } diff --git a/handler/oauth2/introspector_jwt.go b/handler/oauth2/introspector_jwt.go index e8793723..0db62d1c 100644 --- a/handler/oauth2/introspector_jwt.go +++ b/handler/oauth2/introspector_jwt.go @@ -14,52 +14,34 @@ import ( ) type StatelessJWTValidator struct { - jwt.Signer + jwt.Strategy Config interface { oauth2.ScopeStrategyProvider } } -func (v *StatelessJWTValidator) IntrospectToken(ctx context.Context, token string, tokenUse oauth2.TokenUse, accessRequest oauth2.AccessRequester, scopes []string) (oauth2.TokenUse, error) { - t, err := validateJWT(ctx, v.Signer, token) - if err != nil { +func (v *StatelessJWTValidator) IntrospectToken(ctx context.Context, tokenString string, tokenUse oauth2.TokenUse, requester oauth2.AccessRequester, scopes []string) (use oauth2.TokenUse, err error) { + var token *jwt.Token + + if token, err = validateJWT(ctx, v.Strategy, jwt.NewStatelessJWTProfileIntrospectionClient(requester.GetClient()), tokenString); err != nil { return "", err } - if !IsJWTProfileAccessToken(t) { + if !token.IsJWTProfileAccessToken() { return "", errorsx.WithStack(oauth2.ErrRequestUnauthorized.WithDebug("The provided token is not a valid RFC9068 JWT Profile Access Token as it is missing the header 'typ' value of 'at+jwt' ")) } - requester := AccessTokenJWTToRequest(t) + r := AccessTokenJWTToRequest(token) - if err := matchScopes(v.Config.GetScopeStrategy(ctx), requester.GetGrantedScopes(), scopes); err != nil { + if err = matchScopes(v.Config.GetScopeStrategy(ctx), r.GetGrantedScopes(), scopes); err != nil { return oauth2.AccessToken, err } - accessRequest.Merge(requester) + requester.Merge(r) return oauth2.AccessToken, nil } -// IsJWTProfileAccessToken validates a *jwt.Token is actually a RFC9068 JWT Profile Access Token by checking the -// relevant header as per https://datatracker.ietf.org/doc/html/rfc9068#section-2.1 which explicitly states that -// the header MUST include a typ of 'at+jwt' or 'application/at+jwt' with a preference of 'at+jwt'. -func IsJWTProfileAccessToken(token *jwt.Token) bool { - var ( - raw any - typ string - ok bool - ) - - if raw, ok = token.Header[jwt.JWTHeaderKeyValueType]; !ok { - return false - } - - typ, ok = raw.(string) - - return ok && (typ == jwt.JWTHeaderTypeValueAccessTokenJWT || typ == "application/at+jwt") -} - // AccessTokenJWTToRequest tries to reconstruct oauth2.Request from a JWT. func AccessTokenJWTToRequest(token *jwt.Token) oauth2.Requester { mapClaims := token.Claims diff --git a/handler/oauth2/introspector_jwt_test.go b/handler/oauth2/introspector_jwt_test.go index 33cc49a9..9398e1ef 100644 --- a/handler/oauth2/introspector_jwt_test.go +++ b/handler/oauth2/introspector_jwt_test.go @@ -19,8 +19,6 @@ import ( ) func TestIntrospectJWT(t *testing.T) { - rsaKey := gen.MustRSAKey() - config := &oauth2.Config{ EnforceJWTProfileAccessTokens: true, GlobalSecret: []byte("foofoofoofoofoofoofoofoofoofoofoo"), @@ -28,16 +26,15 @@ func TestIntrospectJWT(t *testing.T) { strategy := &JWTProfileCoreStrategy{ HMACCoreStrategy: NewHMACCoreStrategy(config, "authelia_%s_"), - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return rsaKey, nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(gen.MustRSAKey()), }, Config: config, } var v = &StatelessJWTValidator{ - Signer: strategy, + Strategy: strategy, Config: &oauth2.Config{ ScopeStrategy: oauth2.HierarchicScopeStrategy, }, @@ -126,16 +123,18 @@ func TestIntrospectJWT(t *testing.T) { } func BenchmarkIntrospectJWT(b *testing.B) { + config := &oauth2.Config{} + strategy := &JWTProfileCoreStrategy{ - Signer: &jwt.DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(gen.MustRSAKey()), }, - }, - Config: &oauth2.Config{}, + Config: config, } v := &StatelessJWTValidator{ - Signer: strategy, + Strategy: strategy, } jwt := jwtValidCase(oauth2.AccessToken) diff --git a/handler/oauth2/strategy.go b/handler/oauth2/strategy.go index 065539b6..4f165a28 100644 --- a/handler/oauth2/strategy.go +++ b/handler/oauth2/strategy.go @@ -12,13 +12,13 @@ import ( // NewCoreStrategy is a special constructor that if provided a signer will automatically decorate the HMACCoreStrategy // with a JWTProfileCoreStrategy, otherwise it just returns the HMACCoreStrategy. -func NewCoreStrategy(config CoreStrategyConfigurator, prefix string, signer jwt.Signer) (strategy CoreStrategy) { - if signer == nil { +func NewCoreStrategy(config CoreStrategyConfigurator, prefix string, strategy jwt.Strategy) (core CoreStrategy) { + if strategy == nil { return NewHMACCoreStrategy(config, prefix) } return &JWTProfileCoreStrategy{ - Signer: signer, + Strategy: strategy, HMACCoreStrategy: NewHMACCoreStrategy(config, prefix), Config: config, } diff --git a/handler/oauth2/strategy_jwt_profile.go b/handler/oauth2/strategy_jwt_profile.go index d8589563..4d688803 100644 --- a/handler/oauth2/strategy_jwt_profile.go +++ b/handler/oauth2/strategy_jwt_profile.go @@ -18,7 +18,8 @@ import ( // JWTProfileCoreStrategy is a JWT RS256 strategy. type JWTProfileCoreStrategy struct { - jwt.Signer + jwt.Strategy + HMACCoreStrategy *HMACCoreStrategy Config interface { oauth2.AccessTokenIssuerProvider @@ -56,7 +57,7 @@ func (s *JWTProfileCoreStrategy) GenerateAccessToken(ctx context.Context, reques func (s *JWTProfileCoreStrategy) ValidateAccessToken(ctx context.Context, requester oauth2.Requester, tokenString string) (err error) { if possible, _ := s.IsPossiblyJWTProfileAccessToken(ctx, tokenString); possible { - _, err = validateJWT(ctx, s.Signer, tokenString) + _, err = validateJWT(ctx, s.Strategy, jwt.NewJWTProfileAccessTokenClient(requester.GetClient()), tokenString) return } @@ -170,11 +171,11 @@ func (s *JWTProfileCoreStrategy) GenerateJWT(ctx context.Context, tokenType oaut s.Config.GetJWTScopeField(ctx), ) - return s.Signer.Generate(ctx, claims.ToMapClaims(), header) + return s.Strategy.Encode(ctx, jwt.WithClaims(claims.ToMapClaims()), jwt.WithHeaders(header), jwt.WithJWTProfileAccessTokenClient(client)) } -func validateJWT(ctx context.Context, jwtStrategy jwt.Signer, token string) (t *jwt.Token, err error) { - t, err = jwtStrategy.Decode(ctx, token) +func validateJWT(ctx context.Context, jwtStrategy jwt.Strategy, client jwt.Client, token string) (t *jwt.Token, err error) { + t, err = jwtStrategy.Decode(ctx, token, jwt.WithClient(client)) if err == nil { err = t.Claims.Valid() return diff --git a/handler/oauth2/strategy_jwt_profile_test.go b/handler/oauth2/strategy_jwt_profile_test.go index c83861ac..163e2ad8 100644 --- a/handler/oauth2/strategy_jwt_profile_test.go +++ b/handler/oauth2/strategy_jwt_profile_test.go @@ -176,19 +176,18 @@ func TestAccessToken(t *testing.T) { }, } { t.Run(fmt.Sprintf("case=%d/%d", s, k), func(t *testing.T) { - signer := &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return rsaKey, nil - }, - } - config := &oauth2.Config{ EnforceJWTProfileAccessTokens: true, GlobalSecret: []byte("foofoofoofoofoofoofoofoofoofoofoo"), JWTScopeClaimKey: scopeField, } - strategy := NewCoreStrategy(config, "authelia_%s_", signer) + jwtStrategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(rsaKey), + } + + strategy := NewCoreStrategy(config, "authelia_%s_", jwtStrategy) token, signature, err := strategy.GenerateAccessToken(context.TODO(), c.r) assert.NoError(t, err) @@ -221,7 +220,7 @@ func TestAccessToken(t *testing.T) { require.NoError(t, json.Unmarshal(rawHeader, &header)) - assert.Equal(t, jwt.JWTHeaderTypeValueAccessTokenJWT, header[jwt.JWTHeaderKeyValueType]) + assert.Equal(t, consts.JSONWebTokenTypeAccessToken, header[consts.JSONWebTokenHeaderType]) extraClaimsSession, ok := c.r.GetSession().(oauth2.ExtraClaimsSession) require.True(t, ok) diff --git a/handler/oauth2/strategy_jwt_session.go b/handler/oauth2/strategy_jwt_session.go index 54ad57d0..9ec9e155 100644 --- a/handler/oauth2/strategy_jwt_session.go +++ b/handler/oauth2/strategy_jwt_session.go @@ -9,6 +9,7 @@ import ( "github.com/mohae/deepcopy" "authelia.com/provider/oauth2" + "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/token/jwt" ) @@ -42,11 +43,11 @@ func (j *JWTSession) GetJWTHeader() *jwt.Headers { if j.JWTHeader == nil { j.JWTHeader = &jwt.Headers{ Extra: map[string]any{ - jwt.JWTHeaderKeyValueType: jwt.JWTHeaderTypeValueAccessTokenJWT, + consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeAccessToken, }, } - } else if j.JWTHeader.Extra[jwt.JWTHeaderKeyValueType] == nil { - j.JWTHeader.Extra[jwt.JWTHeaderKeyValueType] = jwt.JWTHeaderTypeValueAccessTokenJWT + } else if j.JWTHeader.Extra[consts.JSONWebTokenHeaderType] == nil { + j.JWTHeader.Extra[consts.JSONWebTokenHeaderType] = consts.JSONWebTokenTypeAccessToken } return j.JWTHeader diff --git a/handler/openid/flow_device_authorization_test.go b/handler/openid/flow_device_authorization_test.go index 3c5ce5ff..0f997f84 100644 --- a/handler/openid/flow_device_authorization_test.go +++ b/handler/openid/flow_device_authorization_test.go @@ -27,13 +27,15 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpoin AuthorizeCodeLifespan: time.Minute * 24, RFC8628CodeLifespan: time.Minute * 24, } - j := &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, - }, + + jwtStrategy := &jwt.DefaultStrategy{ Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + + j := &DefaultStrategy{ + Strategy: jwtStrategy, + Config: config, } oidcStore := mock.NewMockOpenIDConnectRequestStorage(ctrl) @@ -41,7 +43,7 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpoin handler := &OpenIDConnectDeviceAuthorizeHandler{ OpenIDConnectRequestStorage: oidcStore, - OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Signer, config), + OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Strategy, config), CodeTokenEndpointHandler: tokenHandler, Config: config, IDTokenHandleHelper: &IDTokenHandleHelper{ @@ -128,13 +130,15 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateTokenEndpointResponse(t *te AuthorizeCodeLifespan: time.Minute * 24, RFC8628CodeLifespan: time.Minute * 24, } - j := &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, - }, + + jwtStrategy := &jwt.DefaultStrategy{ Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + + j := &DefaultStrategy{ + Strategy: jwtStrategy, + Config: config, } oidcStore := mock.NewMockOpenIDConnectRequestStorage(ctrl) @@ -142,7 +146,7 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateTokenEndpointResponse(t *te handler := &OpenIDConnectDeviceAuthorizeHandler{ OpenIDConnectRequestStorage: oidcStore, - OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Signer, config), + OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Strategy, config), CodeTokenEndpointHandler: tokenHandler, Config: config, IDTokenHandleHelper: &IDTokenHandleHelper{ diff --git a/handler/openid/flow_explicit_auth_test.go b/handler/openid/flow_explicit_auth_test.go index d007dd47..4e0d3b1a 100644 --- a/handler/openid/flow_explicit_auth_test.go +++ b/handler/openid/flow_explicit_auth_test.go @@ -120,13 +120,14 @@ func makeOpenIDConnectExplicitHandler(ctrl *gomock.Controller, minParameterEntro store := mock.NewMockOpenIDConnectRequestStorage(ctrl) config := &oauth2.Config{MinParameterEntropy: minParameterEntropy} - var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, - }, + jwtStrategy := &jwt.DefaultStrategy{ Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + + var j = &DefaultStrategy{ + Strategy: jwtStrategy, + Config: config, } return OpenIDConnectExplicitHandler{ @@ -134,7 +135,7 @@ func makeOpenIDConnectExplicitHandler(ctrl *gomock.Controller, minParameterEntro IDTokenHandleHelper: &IDTokenHandleHelper{ IDTokenStrategy: j, }, - OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Signer, config), + OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Strategy, config), Config: config, }, store } diff --git a/handler/openid/flow_explicit_token_test.go b/handler/openid/flow_explicit_token_test.go index 1af76df9..8c655bca 100644 --- a/handler/openid/flow_explicit_token_test.go +++ b/handler/openid/flow_explicit_token_test.go @@ -211,15 +211,18 @@ func TestExplicit_PopulateTokenEndpointResponse(t *testing.T) { aresp := oauth2.NewAccessResponse() areq := oauth2.NewAccessRequest(session) + config := &oauth2.Config{ + MinParameterEntropy: oauth2.MinParameterEntropy, + } + + jwtStrategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, - }, - Config: &oauth2.Config{ - MinParameterEntropy: oauth2.MinParameterEntropy, - }, + Strategy: jwtStrategy, + Config: config, } h := &OpenIDConnectExplicitHandler{ diff --git a/handler/openid/flow_hybrid.go b/handler/openid/flow_hybrid.go index 4d7dd0c8..892b5df0 100644 --- a/handler/openid/flow_hybrid.go +++ b/handler/openid/flow_hybrid.go @@ -21,7 +21,7 @@ type OpenIDConnectHybridHandler struct { OpenIDConnectRequestValidator *OpenIDConnectRequestValidator OpenIDConnectRequestStorage OpenIDConnectRequestStorage - Enigma *jwt.DefaultSigner + Enigma *jwt.DefaultStrategy Config interface { oauth2.IDTokenLifespanProvider diff --git a/handler/openid/flow_hybrid_test.go b/handler/openid/flow_hybrid_test.go index cef1aa76..9364c9b4 100644 --- a/handler/openid/flow_hybrid_test.go +++ b/handler/openid/flow_hybrid_test.go @@ -18,7 +18,6 @@ import ( hoauth2 "authelia.com/provider/oauth2/handler/oauth2" "authelia.com/provider/oauth2/internal" "authelia.com/provider/oauth2/internal/consts" - "authelia.com/provider/oauth2/internal/gen" "authelia.com/provider/oauth2/storage" "authelia.com/provider/oauth2/token/hmac" "authelia.com/provider/oauth2/token/jwt" @@ -432,35 +431,34 @@ var hmacStrategy = &hoauth2.HMACCoreStrategy{ } func makeOpenIDConnectHybridHandler(minParameterEntropy int) OpenIDConnectHybridHandler { + config := &oauth2.Config{ + ScopeStrategy: oauth2.HierarchicScopeStrategy, + MinParameterEntropy: minParameterEntropy, + AccessTokenLifespan: time.Hour, + AuthorizeCodeLifespan: time.Hour, + RefreshTokenLifespan: time.Hour, + } + + jwtStrategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + var idStrategy = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil - }, - }, - Config: &oauth2.Config{ - MinParameterEntropy: minParameterEntropy, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.MustGenDefaultIssuer(), }, + Config: config, } var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, - }, + Strategy: jwtStrategy, Config: &oauth2.Config{ MinParameterEntropy: minParameterEntropy, }, } - config := &oauth2.Config{ - ScopeStrategy: oauth2.HierarchicScopeStrategy, - MinParameterEntropy: minParameterEntropy, - AccessTokenLifespan: time.Hour, - AuthorizeCodeLifespan: time.Hour, - RefreshTokenLifespan: time.Hour, - } return OpenIDConnectHybridHandler{ AuthorizeExplicitGrantHandler: &hoauth2.AuthorizeExplicitGrantHandler{ AuthorizeCodeStrategy: hmacStrategy, @@ -479,7 +477,7 @@ func makeOpenIDConnectHybridHandler(minParameterEntropy int) OpenIDConnectHybrid IDTokenStrategy: idStrategy, }, Config: config, - OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Signer, config), + OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Strategy, config), OpenIDConnectRequestStorage: storage.NewMemoryStore(), } } diff --git a/handler/openid/flow_implicit.go b/handler/openid/flow_implicit.go index 7fe3c47a..e46006f5 100644 --- a/handler/openid/flow_implicit.go +++ b/handler/openid/flow_implicit.go @@ -18,7 +18,7 @@ type OpenIDConnectImplicitHandler struct { AuthorizeImplicitGrantTypeHandler *hoauth2.AuthorizeImplicitGrantTypeHandler OpenIDConnectRequestValidator *OpenIDConnectRequestValidator - RS256JWTStrategy *jwt.DefaultSigner + RS256JWTStrategy *jwt.DefaultStrategy Config interface { oauth2.IDTokenLifespanProvider diff --git a/handler/openid/flow_implicit_test.go b/handler/openid/flow_implicit_test.go index 81350df9..bfd4e80b 100644 --- a/handler/openid/flow_implicit_test.go +++ b/handler/openid/flow_implicit_test.go @@ -17,7 +17,6 @@ import ( hoauth2 "authelia.com/provider/oauth2/handler/oauth2" "authelia.com/provider/oauth2/internal" "authelia.com/provider/oauth2/internal/consts" - "authelia.com/provider/oauth2/internal/gen" "authelia.com/provider/oauth2/storage" "authelia.com/provider/oauth2/token/jwt" ) @@ -30,19 +29,17 @@ func makeOpenIDConnectImplicitHandler(minParameterEntropy int) OpenIDConnectImpl } var idStrategy = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return gen.MustRSAKey(), nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), }, Config: config, } var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), }, Config: config, } @@ -56,7 +53,7 @@ func makeOpenIDConnectImplicitHandler(minParameterEntropy int) OpenIDConnectImpl IDTokenHandleHelper: &IDTokenHandleHelper{ IDTokenStrategy: idStrategy, }, - OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Signer, config), + OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Strategy, config), Config: config, } } diff --git a/handler/openid/flow_refresh_token_test.go b/handler/openid/flow_refresh_token_test.go index 333bfbdb..727a5aad 100644 --- a/handler/openid/flow_refresh_token_test.go +++ b/handler/openid/flow_refresh_token_test.go @@ -82,11 +82,12 @@ func TestOpenIDConnectRefreshHandler_HandleTokenEndpointRequest(t *testing.T) { } func TestOpenIDConnectRefreshHandler_PopulateTokenEndpointResponse(t *testing.T) { + config := &oauth2.Config{} + var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), }, Config: &oauth2.Config{ MinParameterEntropy: oauth2.MinParameterEntropy, @@ -97,7 +98,7 @@ func TestOpenIDConnectRefreshHandler_PopulateTokenEndpointResponse(t *testing.T) IDTokenHandleHelper: &IDTokenHandleHelper{ IDTokenStrategy: j, }, - Config: &oauth2.Config{}, + Config: config, } for _, c := range []struct { areq *oauth2.AccessRequest diff --git a/handler/openid/helper_test.go b/handler/openid/helper_test.go index f72c3a6f..b8041513 100644 --- a/handler/openid/helper_test.go +++ b/handler/openid/helper_test.go @@ -15,16 +15,16 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" - "authelia.com/provider/oauth2/internal/gen" "authelia.com/provider/oauth2/testing/mock" "authelia.com/provider/oauth2/token/jwt" ) var strategy = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil + Strategy: &jwt.DefaultStrategy{ + Config: &oauth2.Config{ + MinParameterEntropy: oauth2.MinParameterEntropy, }, + Issuer: jwt.MustGenDefaultIssuer(), }, Config: &oauth2.Config{ MinParameterEntropy: oauth2.MinParameterEntropy, diff --git a/handler/openid/strategy_jwt.go b/handler/openid/strategy_jwt.go index d573e17a..5ae1bfeb 100644 --- a/handler/openid/strategy_jwt.go +++ b/handler/openid/strategy_jwt.go @@ -110,7 +110,7 @@ func (s *DefaultSession) IDTokenClaims() *jwt.IDTokenClaims { } type DefaultStrategy struct { - jwt.Signer + jwt.Strategy Config interface { oauth2.IDTokenIssuerProvider @@ -141,6 +141,8 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to generate id token because subject is an empty string.")) } + jwtClient := jwt.NewIDTokenClient(requester.GetClient()) + if requester.GetRequestForm().Get(consts.FormParameterGrantType) != consts.GrantTypeRefreshToken { maxAge, err := strconv.ParseInt(requester.GetRequestForm().Get(consts.FormParameterMaximumAge), 10, 64) if err != nil { @@ -190,7 +192,7 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura } if tokenHintString := requester.GetRequestForm().Get(consts.FormParameterIDTokenHint); tokenHintString != "" { - tokenHint, err := h.Signer.Decode(ctx, tokenHintString) + tokenHint, err := h.Strategy.Decode(ctx, tokenHintString, jwt.WithClient(jwtClient)) var ve *jwt.ValidationError if errors.As(err, &ve) && ve.Has(jwt.ValidationErrorExpired) { // Expired ID Tokens are allowed as values to id_token_hint @@ -234,6 +236,7 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura claims.Audience = stringslice.Unique(append(claims.Audience, requester.GetClient().GetID())) claims.IssuedAt = time.Now().UTC() - token, _, err = h.Signer.Generate(ctx, claims.ToMapClaims(), sess.IDTokenHeaders()) + token, _, err = h.Strategy.Encode(ctx, jwt.WithClaims(claims.ToMapClaims()), jwt.WithHeaders(sess.IDTokenHeaders()), jwt.WithClient(jwtClient)) + return token, err } diff --git a/handler/openid/strategy_jwt_test.go b/handler/openid/strategy_jwt_test.go index fe6d2ae4..4b1377c5 100644 --- a/handler/openid/strategy_jwt_test.go +++ b/handler/openid/strategy_jwt_test.go @@ -17,14 +17,16 @@ import ( ) func TestJWTStrategy_GenerateIDToken(t *testing.T) { + config := &oauth2.Config{ + MinParameterEntropy: oauth2.MinParameterEntropy, + } + var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }}, - Config: &oauth2.Config{ - MinParameterEntropy: oauth2.MinParameterEntropy, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), }, + Config: config, } var req *oauth2.AccessRequest diff --git a/handler/openid/validator.go b/handler/openid/validator.go index 8fc147c5..e432f3a9 100644 --- a/handler/openid/validator.go +++ b/handler/openid/validator.go @@ -31,11 +31,11 @@ type openIDConnectRequestValidatorConfigProvider interface { } type OpenIDConnectRequestValidator struct { - Strategy jwt.Signer + Strategy jwt.Strategy Config openIDConnectRequestValidatorConfigProvider } -func NewOpenIDConnectRequestValidator(strategy jwt.Signer, config openIDConnectRequestValidatorConfigProvider) *OpenIDConnectRequestValidator { +func NewOpenIDConnectRequestValidator(strategy jwt.Strategy, config openIDConnectRequestValidatorConfigProvider) *OpenIDConnectRequestValidator { return &OpenIDConnectRequestValidator{ Strategy: strategy, Config: config, @@ -144,7 +144,7 @@ func (v *OpenIDConnectRequestValidator) ValidatePrompt(ctx context.Context, req return nil } - tokenHint, err := v.Strategy.Decode(ctx, idTokenHint) + tokenHint, err := v.Strategy.Decode(ctx, idTokenHint, jwt.WithIDTokenClient(req.GetClient())) var ve *jwt.ValidationError if errors.As(err, &ve) && ve.Has(jwt.ValidationErrorExpired) { // Expired tokens are ok diff --git a/handler/openid/validator_test.go b/handler/openid/validator_test.go index 0aedb23f..5a9c805f 100644 --- a/handler/openid/validator_test.go +++ b/handler/openid/validator_test.go @@ -21,11 +21,12 @@ func TestValidatePrompt(t *testing.T) { config := &oauth2.Config{ MinParameterEntropy: oauth2.MinParameterEntropy, } + var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }}, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + }, Config: &oauth2.Config{ MinParameterEntropy: oauth2.MinParameterEntropy, }, @@ -34,7 +35,7 @@ func TestValidatePrompt(t *testing.T) { v := NewOpenIDConnectRequestValidator(j, config) var genIDToken = func(c jwt.IDTokenClaims) string { - s, _, err := j.Generate(context.TODO(), c.ToMapClaims(), jwt.NewHeaders()) + s, _, err := j.Encode(context.TODO(), jwt.WithClaims(c.ToMapClaims())) require.NoError(t, err) return s } diff --git a/handler/rfc7523/handler.go b/handler/rfc7523/handler.go index 166320d4..8054bb8e 100644 --- a/handler/rfc7523/handler.go +++ b/handler/rfc7523/handler.go @@ -45,8 +45,8 @@ var ( // TODO: Refactor time permitting. // //nolint:gocyclo -func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request oauth2.AccessRequester) error { - if err := c.CheckRequest(ctx, request); err != nil { +func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request oauth2.AccessRequester) (err error) { + if err = c.CheckRequest(ctx, request); err != nil { return err } diff --git a/handler/rfc8693/custom_jwt_type_handler.go b/handler/rfc8693/custom_jwt_type_handler.go index 768900f6..297c5349 100644 --- a/handler/rfc8693/custom_jwt_type_handler.go +++ b/handler/rfc8693/custom_jwt_type_handler.go @@ -15,8 +15,9 @@ import ( ) type CustomJWTTypeHandler struct { - Config oauth2.RFC8693ConfigProvider - JWTStrategy jwt.Signer + Config oauth2.RFC8693ConfigProvider + + jwt.Strategy Storage } @@ -154,7 +155,7 @@ func (c *CustomJWTTypeHandler) validate(ctx context.Context, _ oauth2.AccessRequ } } - return map[string]any(claims), nil + return claims, nil } func (c *CustomJWTTypeHandler) issue(ctx context.Context, request oauth2.AccessRequester, tokenType oauth2.RFC8693TokenType, response oauth2.AccessResponder) error { @@ -202,7 +203,7 @@ func (c *CustomJWTTypeHandler) issue(ctx context.Context, request oauth2.AccessR claims.IssuedAt = time.Now().UTC() - token, _, err := c.JWTStrategy.Generate(ctx, claims.ToMapClaims(), sess.IDTokenHeaders()) + token, _, err := c.Strategy.Encode(ctx, jwt.WithClaims(claims.ToMapClaims()), jwt.WithHeaders(sess.IDTokenHeaders()), jwt.WithIDTokenClient(request.GetClient())) if err != nil { return err } @@ -210,6 +211,7 @@ func (c *CustomJWTTypeHandler) issue(ctx context.Context, request oauth2.AccessR response.SetAccessToken(token) response.SetTokenType("N_A") response.SetExpiresIn(time.Duration(claims.ExpiresAt.UnixNano() - time.Now().UTC().UnixNano())) + return nil } diff --git a/handler/rfc8693/id_token_type_handler.go b/handler/rfc8693/id_token_type_handler.go index 6a768563..b000fbb6 100644 --- a/handler/rfc8693/id_token_type_handler.go +++ b/handler/rfc8693/id_token_type_handler.go @@ -12,7 +12,7 @@ import ( type IDTokenTypeHandler struct { Config oauth2.Configurator - JWTStrategy jwt.Signer + Strategy jwt.Strategy IssueStrategy openid.OpenIDConnectTokenStrategy ValidationStrategy openid.TokenValidationStrategy Storage @@ -118,7 +118,7 @@ func (c *IDTokenTypeHandler) validate(ctx context.Context, request oauth2.Access return nil, errorsx.WithStack(oauth2.ErrInvalidRequest.WithHint("Claim 'sub' is missing.")) } - return map[string]any(claims), nil + return claims, nil } func (c *IDTokenTypeHandler) issue(ctx context.Context, request oauth2.AccessRequester, response oauth2.AccessResponder) error { diff --git a/handler/rfc8693/token_exchange_test.go b/handler/rfc8693/token_exchange_test.go index e1a595d0..b34f1aff 100644 --- a/handler/rfc8693/token_exchange_test.go +++ b/handler/rfc8693/token_exchange_test.go @@ -30,12 +30,6 @@ func TestAccessTokenExchangeImpersonation(t *testing.T) { store := storage.NewExampleStore() jwtName := "urn:custom:jwt" - jwtSigner := &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, - } - customJWTType := &JWTType{ Name: jwtName, JWTValidationConfig: JWTValidationConfig{ @@ -70,6 +64,11 @@ func TestAccessTokenExchangeImpersonation(t *testing.T) { DefaultRequestedTokenType: consts.TokenTypeRFC8693AccessToken, } + strategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + coreStrategy := &hoauth2.HMACCoreStrategy{ Enigma: &hmac.HMACStrategy{Config: config}, Config: config, @@ -93,10 +92,9 @@ func TestAccessTokenExchangeImpersonation(t *testing.T) { customJWTHandler := &CustomJWTTypeHandler{ Config: config, - JWTStrategy: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), }, Storage: store, } @@ -142,7 +140,7 @@ func TestAccessTokenExchangeImpersonation(t *testing.T) { Client: store.Clients["my-client"], Form: url.Values{ "subject_token_type": []string{jwtName}, - "subject_token": []string{createJWT(context.Background(), jwtSigner, jwt.MapClaims{ + "subject_token": []string{createJWT(context.Background(), store.Clients["my-client"], strategy, jwt.MapClaims{ "subject": "peter_for_jwt", "jti": uuid.New(), "iss": "https://customory.com", @@ -267,8 +265,9 @@ func createAccessToken(ctx context.Context, coreStrategy hoauth2.CoreStrategy, s return token } -func createJWT(ctx context.Context, signer jwt.Signer, claims jwt.MapClaims) string { - token, _, err := signer.Generate(ctx, claims, &jwt.Headers{}) +func createJWT(ctx context.Context, client any, strategy jwt.Strategy, claims jwt.MapClaims) string { + token, _, err := strategy.Encode(ctx, jwt.WithClaims(claims), jwt.WithIDTokenClient(client)) + if err != nil { panic(err.Error()) } diff --git a/integration/helper_setup_test.go b/integration/helper_setup_test.go index e5ff0038..436f650b 100644 --- a/integration/helper_setup_test.go +++ b/integration/helper_setup_test.go @@ -4,7 +4,6 @@ package integration_test import ( - "context" "crypto" "crypto/rand" "crypto/rsa" @@ -186,10 +185,9 @@ var hmacStrategy = &hoauth2.HMACCoreStrategy{ var defaultRSAKey = gen.MustRSAKey() var jwtStrategy = &hoauth2.JWTProfileCoreStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return defaultRSAKey, nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: &oauth2.Config{}, + Issuer: jwt.NewDefaultIssuerRS256Unverified(defaultRSAKey), }, Config: &oauth2.Config{}, HMACCoreStrategy: hmacStrategy, diff --git a/integration/introspect_token_test.go b/integration/introspect_token_test.go index 72adec7d..6d1ffa46 100644 --- a/integration/introspect_token_test.go +++ b/integration/introspect_token_test.go @@ -26,10 +26,9 @@ func TestIntrospectToken(t *testing.T) { EnforceJWTProfileAccessTokens: true, } - signer := &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return defaultRSAKey, nil - }, + strategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(defaultRSAKey), } for _, c := range []struct { @@ -44,12 +43,12 @@ func TestIntrospectToken(t *testing.T) { }, { description: "JWT strategy with OAuth2TokenIntrospectionFactory", - strategy: hoauth2.NewCoreStrategy(config, "authelia_%s_", signer), + strategy: hoauth2.NewCoreStrategy(config, "authelia_%s_", strategy), factory: compose.OAuth2TokenIntrospectionFactory, }, { description: "JWT strategy with OAuth2StatelessJWTIntrospectionFactory", - strategy: hoauth2.NewCoreStrategy(config, "authelia_%s_", signer), + strategy: hoauth2.NewCoreStrategy(config, "authelia_%s_", strategy), factory: compose.OAuth2StatelessJWTIntrospectionFactory, }, } { diff --git a/internal/consts/client_auth_method.go b/internal/consts/client_auth_method.go index 982284d0..b2bec6c9 100644 --- a/internal/consts/client_auth_method.go +++ b/internal/consts/client_auth_method.go @@ -1,6 +1,6 @@ package consts -// Client Auth Method strings. +// Client Auth SignatureAlgorithm strings. const ( ClientAuthMethodClientSecretBasic = "client_secret_basic" ClientAuthMethodClientSecretPost = "client_secret_post" diff --git a/internal/consts/const.go b/internal/consts/const.go index 22911e86..5fbd19de 100644 --- a/internal/consts/const.go +++ b/internal/consts/const.go @@ -16,4 +16,5 @@ const ( valueNonce = "nonce" valueDeviceCode = "device_code" valueUserCode = "user_code" + valueEnc = "enc" ) diff --git a/internal/consts/jwt.go b/internal/consts/jwt.go index 37ddff52..9dc52a9c 100644 --- a/internal/consts/jwt.go +++ b/internal/consts/jwt.go @@ -1,21 +1,27 @@ package consts const ( - JSONWebTokenHeaderKeyIdentifier = "kid" - JSONWebTokenHeaderAlgorithm = "alg" - JSONWebTokenHeaderUse = "use" - JSONWebTokenHeaderType = "typ" + JSONWebTokenHeaderKeyIdentifier = "kid" + JSONWebTokenHeaderAlgorithm = "alg" + JSONWebTokenHeaderEncryptionAlgorithm = valueEnc + JSONWebTokenHeaderCompressionAlgorithm = "zip" + JSONWebTokenHeaderPBES2Count = "p2c" + + JSONWebTokenHeaderUse = "use" + JSONWebTokenHeaderType = "typ" + JSONWebTokenHeaderContentType = "cty" ) const ( JSONWebTokenUseSignature = "sig" - JSONWebTokenUseEncryption = "enc" + JSONWebTokenUseEncryption = valueEnc ) const ( - JSONWebTokenTypeJWT = "JWT" - JSONWebTokenTypeAccessToken = "at+jwt" - JSONWebTokenTypeTokenIntrospection = "token-introspection+jwt" + JSONWebTokenTypeJWT = "JWT" + JSONWebTokenTypeAccessToken = "at+jwt" + JSONWebTokenTypeAccessTokenAlternative = "application/at+jwt" + JSONWebTokenTypeTokenIntrospection = "token-introspection+jwt" ) const ( diff --git a/internal/consts/spec.go b/internal/consts/spec.go index db8742b8..bdfa6647 100644 --- a/internal/consts/spec.go +++ b/internal/consts/spec.go @@ -7,7 +7,7 @@ const ( PromptTypeSelectAccount = "select_account" ) -// Proof Key Code Exchange Challenge Method strings. +// Proof Key Code Exchange Challenge SignatureAlgorithm strings. const ( PKCEChallengeMethodPlain = "plain" PKCEChallengeMethodSHA256 = "S256" diff --git a/introspection_response_writer.go b/introspection_response_writer.go index c1b86104..abb59121 100644 --- a/introspection_response_writer.go +++ b/introspection_response_writer.go @@ -283,7 +283,7 @@ func (f *Fosite) writeIntrospectionResponse(ctx context.Context, rw http.Respons claims[consts.ClaimAudience] = aud } - signer := f.Config.GetIntrospectionJWTResponseSigner(ctx) + signer := f.Config.GetIntrospectionJWTResponseStrategy(ctx) if signer == nil { f.WriteIntrospectionError(ctx, rw, errors.WithStack(ErrServerError.WithHint("Failed to generate the response.").WithDebug("The Introspection JWT could not be generated as the server is misconfigured. The Introspection Signer was not configured."))) @@ -291,7 +291,7 @@ func (f *Fosite) writeIntrospectionResponse(ctx context.Context, rw http.Respons return } - if token, _, err = signer.Generate(ctx, claims, header); err != nil { + if token, _, err = signer.Encode(ctx, jwt.WithClaims(claims), jwt.WithHeaders(header), jwt.WithIntrospectionClient(r.GetAccessRequester().GetClient())); err != nil { f.WriteIntrospectionError(ctx, rw, errors.WithStack(ErrServerError.WithHint("Failed to generate the response.").WithDebugf("The Introspection JWT itself could not be generated with error %+v.", err))) return diff --git a/response_handler.go b/response_handler.go index 6823956c..051f3495 100644 --- a/response_handler.go +++ b/response_handler.go @@ -256,7 +256,7 @@ type ResponseModeHandlerConfigurator interface { FormPostHTMLTemplateProvider FormPostResponseProvider JWTSecuredAuthorizeResponseModeIssuerProvider - JWTSecuredAuthorizeResponseModeSignerProvider + JWTSecuredAuthorizeResponseModeStrategyProvider JWTSecuredAuthorizeResponseModeLifespanProvider MessageCatalogProvider SendDebugMessagesToClientsProvider diff --git a/testing/mock/client.go b/testing/mock/client.go index 4d43f535..1532fc6c 100644 --- a/testing/mock/client.go +++ b/testing/mock/client.go @@ -67,6 +67,21 @@ func (mr *MockClientMockRecorder) GetClientSecret() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientSecret", reflect.TypeOf((*MockClient)(nil).GetClientSecret)) } +// GetClientSecretPlainText mocks base method. +func (m *MockClient) GetClientSecretPlainText() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClientSecretPlainText") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetClientSecretPlainText indicates an expected call of GetClientSecretPlainText. +func (mr *MockClientMockRecorder) GetClientSecretPlainText() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientSecretPlainText", reflect.TypeOf((*MockClient)(nil).GetClientSecretPlainText)) +} + // GetGrantTypes mocks base method. func (m *MockClient) GetGrantTypes() oauth2.Arguments { m.ctrl.T.Helper() diff --git a/token/jarm/generate.go b/token/jarm/generate.go index bb4e18bd..07fc1ff3 100644 --- a/token/jarm/generate.go +++ b/token/jarm/generate.go @@ -73,11 +73,11 @@ func Generate(ctx context.Context, config Configurator, client Client, session a claims.Extra[param] = in.Get(param) } - var signer jwt.Signer + var signer jwt.Strategy - if signer = config.GetJWTSecuredAuthorizeResponseModeSigner(ctx); signer == nil { + if signer = config.GetJWTSecuredAuthorizeResponseModeStrategy(ctx); signer == nil { return "", "", errors.New("The JARM response modes require the JWTSecuredAuthorizeResponseModeSignerProvider to return a jwt.Signer but it didn't.") } - return signer.Generate(ctx, claims.ToMapClaims(), &jwt.Headers{Extra: headers}) + return signer.Encode(ctx, jwt.WithClaims(claims.ToMapClaims()), jwt.WithHeaders(&jwt.Headers{Extra: headers}), jwt.WithJARMClient(client)) } diff --git a/token/jarm/types.go b/token/jarm/types.go index 2ad362da..e82bea65 100644 --- a/token/jarm/types.go +++ b/token/jarm/types.go @@ -9,7 +9,7 @@ import ( type Configurator interface { GetJWTSecuredAuthorizeResponseModeIssuer(ctx context.Context) string - GetJWTSecuredAuthorizeResponseModeSigner(ctx context.Context) jwt.Signer + GetJWTSecuredAuthorizeResponseModeStrategy(ctx context.Context) jwt.Strategy GetJWTSecuredAuthorizeResponseModeLifespan(ctx context.Context) time.Duration } @@ -17,6 +17,8 @@ type Client interface { // GetID returns the client ID. GetID() (id string) + IsPublic() (public bool) + // GetAuthorizationSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements of the // JWT-secured Authorization Response Method (JARM) specifications. If unspecified the other available parameters // will be utilized to select an appropriate key. diff --git a/token/jwt/client.go b/token/jwt/client.go new file mode 100644 index 00000000..5f723492 --- /dev/null +++ b/token/jwt/client.go @@ -0,0 +1,468 @@ +package jwt + +import ( + "github.com/go-jose/go-jose/v4" +) + +func NewJARClient(client any) Client { + switch c := client.(type) { + case JARClient: + return &decoratedJARClient{JARClient: c} + default: + return nil + } +} + +func NewIDTokenClient(client any) Client { + switch c := client.(type) { + case IDTokenClient: + return &decoratedIDTokenClient{IDTokenClient: c} + default: + return nil + } +} + +func NewJARMClient(client any) Client { + switch c := client.(type) { + case JARMClient: + return &decoratedJARMClient{JARMClient: c} + default: + return nil + } +} + +func NewUserInfoClient(client any) Client { + switch c := client.(type) { + case UserInfoClient: + return &decoratedUserInfoClient{UserInfoClient: c} + default: + return nil + } +} + +func NewJWTProfileAccessTokenClient(client any) Client { + switch c := client.(type) { + case JWTProfileAccessTokenClient: + return &decoratedJWTProfileAccessTokenClient{JWTProfileAccessTokenClient: c} + default: + return nil + } +} + +func NewIntrospectionClient(client any) Client { + switch c := client.(type) { + case IntrospectionClient: + return &decoratedIntrospectionClient{IntrospectionClient: c} + default: + return nil + } +} + +func NewStatelessJWTProfileIntrospectionClient(client any) Client { + switch c := client.(type) { + case IntrospectionClient: + return &decoratedIntrospectionClient{IntrospectionClient: c} + case JWTProfileAccessTokenClient: + return &decoratedJWTProfileAccessTokenClient{JWTProfileAccessTokenClient: c} + default: + return nil + } +} + +type Client interface { + GetSignatureKeyID() (kid string) + GetSignatureAlg() (alg string) + GetEncryptionKeyID() (kid string) + GetEncryptionAlg() (alg string) + GetEncryptionEnc() (enc string) + + IsClientSigned() (is bool) + + BaseClient +} + +type BaseClient interface { + // GetClientSecretPlainText returns the ClientSecret as plaintext if available. + GetClientSecretPlainText() (secret []byte, err error) + + // GetJSONWebKeys returns the JSON Web Key Set containing the public key used by the client to authenticate. + GetJSONWebKeys() (jwks *jose.JSONWebKeySet) + + // GetJSONWebKeysURI returns the URL for lookup of JSON Web Key Set containing the + // public key used by the client to authenticate. + GetJSONWebKeysURI() (uri string) +} + +type JARClient interface { + // GetRequestObjectSigningKeyID returns the specific key identifier used to satisfy JWS requirements of the request + // object specifications. If unspecified the other available parameters will be utilized to select an appropriate + // key. + GetRequestObjectSigningKeyID() (kid string) + + // GetRequestObjectSigningAlg is equivalent to the 'request_object_signing_alg' client metadata + // value which determines the JWS alg algorithm [JWA] that MUST be used for signing Request Objects sent to the OP. + // All Request Objects from this Client MUST be rejected, if not signed with this algorithm. Request Objects are + // described in Section 6.1 of OpenID Connect Core 1.0 [OpenID.Core]. This algorithm MUST be used both when the + // Request Object is passed by value (using the request parameter) and when it is passed by reference (using the + // request_uri parameter). Servers SHOULD support RS256. The value none MAY be used. The default, if omitted, is + // that any algorithm supported by the OP and the RP MAY be used. + GetRequestObjectSigningAlg() (alg string) + + // GetRequestObjectEncryptionKeyID returns the specific key identifier used to satisfy JWE requirements of the + // request object specifications. If unspecified the other available parameters will be utilized to select an + // appropriate key. + GetRequestObjectEncryptionKeyID() (kid string) + + // GetRequestObjectEncryptionAlg is equivalent to the 'request_object_encryption_alg' client metadata value which + // determines the JWE alg algorithm [JWA] the RP is declaring that it may use for encrypting Request Objects sent to + // the OP. This parameter SHOULD be included when symmetric encryption will be used, since this signals to the OP + // that a client_secret value needs to be returned from which the symmetric key will be derived, that might not + // otherwise be returned. The RP MAY still use other supported encryption algorithms or send unencrypted Request + // Objects, even when this parameter is present. If both signing and encryption are requested, the Request Object + // will be signed then encrypted, with the result being a Nested JWT, as defined in [JWT]. The default, if omitted, + // is that the RP is not declaring whether it might encrypt any Request Objects. + GetRequestObjectEncryptionAlg() (alg string) + + // GetRequestObjectEncryptionEnc is equivalent to the 'request_object_encryption_enc' client metadata value which + // determines the JWE enc algorithm [JWA] the RP is declaring that it may use for encrypting Request Objects sent to + // the OP. If request_object_encryption_alg is specified, the default request_object_encryption_enc value is + // A128CBC-HS256. When request_object_encryption_enc is included, request_object_encryption_alg MUST also be + // provided. + GetRequestObjectEncryptionEnc() (enc string) + + BaseClient +} + +type decoratedJARClient struct { + JARClient +} + +func (r *decoratedJARClient) GetSignatureKeyID() (kid string) { + return r.GetRequestObjectSigningKeyID() +} + +func (r *decoratedJARClient) GetSignatureAlg() (alg string) { + return r.GetRequestObjectSigningAlg() +} + +func (r *decoratedJARClient) GetEncryptionKeyID() (kid string) { + return r.GetRequestObjectEncryptionKeyID() +} + +func (r *decoratedJARClient) GetEncryptionAlg() (alg string) { + return r.GetRequestObjectEncryptionAlg() +} + +func (r *decoratedJARClient) GetEncryptionEnc() (enc string) { + return r.GetRequestObjectEncryptionEnc() +} + +func (r *decoratedJARClient) IsClientSigned() (is bool) { + return true +} + +type IDTokenClient interface { + // GetIDTokenSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements of the ID + // Token specifications. If unspecified the other available parameters will be utilized to select an appropriate + // key. + GetIDTokenSignedResponseKeyID() (kid string) + + // GetIDTokenSignedResponseAlg is equivalent to the 'id_token_signed_response_alg' client metadata value which + // determines the JWS alg algorithm [JWA] REQUIRED for signing the ID Token issued to this Client. The value none + // MUST NOT be used as the ID Token alg value unless the Client uses only Response Types that return no ID Token + // from the Authorization Endpoint (such as when only using the Authorization Code Flow). The default, if omitted, + // is RS256. The public key for validating the signature is provided by retrieving the JWK Set referenced by the + // jwks_uri element from OpenID Connect Discovery 1.0 [OpenID.Discovery]. + GetIDTokenSignedResponseAlg() (alg string) + + // GetIDTokenEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements of the ID + // Token specifications. If unspecified the other available parameters will be utilized to select an appropriate + // key. + GetIDTokenEncryptedResponseKeyID() (kid string) + + // GetIDTokenEncryptedResponseAlg is equivalent to the 'id_token_encrypted_response_alg' client metadata value which + // determines the JWE alg algorithm [JWA] REQUIRED for encrypting the ID Token issued to this Client. If this is + // requested, the response will be signed then encrypted, with the result being a Nested JWT, as defined in [JWT]. + // The default, if omitted, is that no encryption is performed. + GetIDTokenEncryptedResponseAlg() (alg string) + + // GetIDTokenEncryptedResponseEnc is equivalent to the 'id_token_encrypted_response_enc' client metadata value which + // determines the JWE enc algorithm [JWA] REQUIRED for encrypting the ID Token issued to this Client. If + // id_token_encrypted_response_alg is specified, the default id_token_encrypted_response_enc value is A128CBC-HS256. + // When id_token_encrypted_response_enc is included, id_token_encrypted_response_alg MUST also be provided. + GetIDTokenEncryptedResponseEnc() (enc string) + + BaseClient +} + +type decoratedIDTokenClient struct { + IDTokenClient +} + +func (r *decoratedIDTokenClient) GetSignatureKeyID() (kid string) { + return r.GetIDTokenSignedResponseKeyID() +} + +func (r *decoratedIDTokenClient) GetSignatureAlg() (alg string) { + return r.GetIDTokenSignedResponseAlg() +} + +func (r *decoratedIDTokenClient) GetEncryptionKeyID() (kid string) { + return r.GetIDTokenEncryptedResponseKeyID() +} + +func (r *decoratedIDTokenClient) GetEncryptionAlg() (alg string) { + return r.GetIDTokenEncryptedResponseAlg() +} + +func (r *decoratedIDTokenClient) GetEncryptionEnc() (enc string) { + return r.GetIDTokenEncryptedResponseEnc() +} + +func (r *decoratedIDTokenClient) IsClientSigned() (is bool) { + return false +} + +type JARMClient interface { + // GetAuthorizationSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements of the + // JWT-secured Authorization Response Method (JARM) specifications. If unspecified the other available parameters + // will be utilized to select an appropriate key. + GetAuthorizationSignedResponseKeyID() (kid string) + + // GetAuthorizationSignedResponseAlg is equivalent to the 'authorization_signed_response_alg' client metadata + // value which determines the JWS [RFC7515] alg algorithm JWA [RFC7518] REQUIRED for signing authorization + // responses. If this is specified, the response will be signed using JWS and the configured algorithm. The + // algorithm none is not allowed. The default, if omitted, is RS256. + GetAuthorizationSignedResponseAlg() (alg string) + + // GetAuthorizationEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements of + // the JWT-secured Authorization Response Method (JARM) specifications. If unspecified the other available parameters will be + // utilized to select an appropriate key. + GetAuthorizationEncryptedResponseKeyID() (kid string) + + // GetAuthorizationEncryptedResponseAlg is equivalent to the 'authorization_encrypted_response_alg' client metadata + // value which determines the JWE [RFC7516] alg algorithm JWA [RFC7518] REQUIRED for encrypting authorization + // responses. If both signing and encryption are requested, the response will be signed then encrypted, with the + // result being a Nested JWT, as defined in JWT [RFC7519]. The default, if omitted, is that no encryption is + // performed. + GetAuthorizationEncryptedResponseAlg() (alg string) + + // GetAuthorizationEncryptedResponseEnc is equivalent to the 'authorization_encrypted_response_enc' client + // metadata value which determines the JWE [RFC7516] enc algorithm JWA [RFC7518] REQUIRED for encrypting + // authorization responses. If authorization_encrypted_response_alg is specified, the default for this value is + // A128CBC-HS256. When authorization_encrypted_response_enc is included, authorization_encrypted_response_alg MUST + // also be provided. + GetAuthorizationEncryptedResponseEnc() (alg string) + + BaseClient +} + +type decoratedJARMClient struct { + JARMClient +} + +func (r *decoratedJARMClient) GetSignatureKeyID() (kid string) { + return r.GetAuthorizationSignedResponseKeyID() +} + +func (r *decoratedJARMClient) GetSignatureAlg() (alg string) { + return r.GetAuthorizationSignedResponseAlg() +} + +func (r *decoratedJARMClient) GetEncryptionKeyID() (kid string) { + return r.GetAuthorizationEncryptedResponseKeyID() +} + +func (r *decoratedJARMClient) GetEncryptionAlg() (alg string) { + return r.GetAuthorizationEncryptedResponseAlg() +} + +func (r *decoratedJARMClient) GetEncryptionEnc() (enc string) { + return r.GetAuthorizationEncryptedResponseEnc() +} + +func (r *decoratedJARMClient) IsClientSigned() (is bool) { + return false +} + +type UserInfoClient interface { + // GetUserinfoSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements of the User + // Info specifications. If unspecified the other available parameters will be utilized to select an appropriate + // key. + GetUserinfoSignedResponseKeyID() (kid string) + + // GetUserinfoSignedResponseAlg is equivalent to the 'userinfo_signed_response_alg' client metadata value which + // determines the JWS alg algorithm [JWA] REQUIRED for signing UserInfo Responses. If this is specified, the + // response will be JWT [JWT] serialized, and signed using JWS. The default, if omitted, is for the UserInfo + // Response to return the Claims as a UTF-8 [RFC3629] encoded JSON object using the application/json content-type. + GetUserinfoSignedResponseAlg() (alg string) + + // GetUserinfoEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements of the + // User Info specifications. If unspecified the other available parameters will be utilized to select an appropriate + // key. + GetUserinfoEncryptedResponseKeyID() (kid string) + + // GetUserinfoEncryptedResponseAlg is equivalent to the 'userinfo_encrypted_response_alg' client metadata value + // which determines the JWE alg algorithm [JWA] REQUIRED for encrypting the ID Token issued to this Client. If + // this is requested, the response will be signed then encrypted, with the result being a Nested JWT, as defined in + // [JWT]. The default, if omitted, is that no encryption is performed. + GetUserinfoEncryptedResponseAlg() (alg string) + + // GetUserinfoEncryptedResponseEnc is equivalent to the 'userinfo_encrypted_response_enc' client metadata value + // which determines the JWE enc algorithm [JWA] REQUIRED for encrypting UserInfo Responses. If + // userinfo_encrypted_response_alg is specified, the default userinfo_encrypted_response_enc value is A128CBC-HS256. + // When userinfo_encrypted_response_enc is included, userinfo_encrypted_response_alg MUST also be provided. + GetUserinfoEncryptedResponseEnc() (enc string) + + BaseClient +} + +type decoratedUserInfoClient struct { + UserInfoClient +} + +func (r *decoratedUserInfoClient) GetSignatureKeyID() (kid string) { + return r.GetUserinfoSignedResponseKeyID() +} + +func (r *decoratedUserInfoClient) GetSignatureAlg() (alg string) { + return r.GetUserinfoSignedResponseAlg() +} + +func (r *decoratedUserInfoClient) GetEncryptionKeyID() (kid string) { + return r.GetUserinfoEncryptedResponseKeyID() +} + +func (r *decoratedUserInfoClient) GetEncryptionAlg() (alg string) { + return r.GetUserinfoEncryptedResponseAlg() +} + +func (r *decoratedUserInfoClient) GetEncryptionEnc() (enc string) { + return r.GetUserinfoEncryptedResponseEnc() +} + +func (r *decoratedUserInfoClient) IsClientSigned() (is bool) { + return false +} + +type JWTProfileAccessTokenClient interface { + // GetAccessTokenSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements for + // JWT Profile for OAuth 2.0 Access Tokens specifications. If unspecified the other available parameters will be + // utilized to select an appropriate key. + GetAccessTokenSignedResponseKeyID() (kid string) + + // GetAccessTokenSignedResponseAlg determines the JWS [RFC7515] algorithm (alg value) as defined in JWA [RFC7518] + // for signing JWT Profile Access Token responses. If this is specified, the response will be signed using JWS and + // the configured algorithm. The default, if omitted, is none; i.e. unsigned responses unless the + // GetEnableJWTProfileOAuthAccessTokens receiver returns true in which case the default is RS256. + GetAccessTokenSignedResponseAlg() (alg string) + + // GetAccessTokenEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements for + // JWT Profile for OAuth 2.0 Access Tokens specifications. If unspecified the other available parameters will be + // utilized to select an appropriate key. + GetAccessTokenEncryptedResponseKeyID() (kid string) + + // GetAccessTokenEncryptedResponseAlg determines the JWE [RFC7516] algorithm (alg value) as defined in JWA [RFC7518] + // for content key encryption. If this is specified, the response will be encrypted using JWE and the configured + // content encryption algorithm (access_token_encrypted_response_enc). The default, if omitted, is that no + // encryption is performed. If both signing and encryption are requested, the response will be signed then + // encrypted, with the result being a Nested JWT, as defined in JWT [RFC7519]. + GetAccessTokenEncryptedResponseAlg() (alg string) + + // GetAccessTokenEncryptedResponseEnc determines the JWE [RFC7516] algorithm (enc value) as defined in JWA [RFC7518] + // for content encryption of access token responses. The default, if omitted, is A128CBC-HS256. Note: This parameter + // MUST NOT be specified without setting access_token_encrypted_response_alg. + GetAccessTokenEncryptedResponseEnc() (alg string) + + BaseClient +} + +type decoratedJWTProfileAccessTokenClient struct { + JWTProfileAccessTokenClient +} + +func (r *decoratedJWTProfileAccessTokenClient) GetSignatureKeyID() (kid string) { + return r.GetAccessTokenSignedResponseKeyID() +} + +func (r *decoratedJWTProfileAccessTokenClient) GetSignatureAlg() (alg string) { + return r.GetAccessTokenSignedResponseAlg() +} + +func (r *decoratedJWTProfileAccessTokenClient) GetEncryptionKeyID() (kid string) { + return r.GetAccessTokenEncryptedResponseKeyID() +} + +func (r *decoratedJWTProfileAccessTokenClient) GetEncryptionAlg() (alg string) { + return r.GetAccessTokenEncryptedResponseAlg() +} + +func (r *decoratedJWTProfileAccessTokenClient) GetEncryptionEnc() (enc string) { + return r.GetAccessTokenEncryptedResponseEnc() +} + +func (r *decoratedJWTProfileAccessTokenClient) IsClientSigned() (is bool) { + return false +} + +type IntrospectionClient interface { + // GetIntrospectionSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements for + // OAuth 2.0 JWT introspection response specifications. If unspecified the other available parameters will be + // // utilized to select an appropriate key. + GetIntrospectionSignedResponseKeyID() (kid string) + + // GetIntrospectionSignedResponseAlg is equivalent to the 'introspection_signed_response_alg' client metadata + // value which determines the JWS [RFC7515] algorithm (alg value) as defined in JWA [RFC7518] for signing + // introspection responses. If this is specified, the response will be signed using JWS and the configured + // algorithm. The default, if omitted, is RS256. + GetIntrospectionSignedResponseAlg() (alg string) + + // GetIntrospectionEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements for + // OAuth 2.0 JWT introspection response specifications. If unspecified the other available parameters will be + // // utilized to select an appropriate key. + GetIntrospectionEncryptedResponseKeyID() (kid string) + + // GetIntrospectionEncryptedResponseAlg is equivalent to the 'introspection_encrypted_response_alg' client metadata + // value which determines the JWE [RFC7516] algorithm (alg value) as defined in JWA [RFC7518] for content key + // encryption. If this is specified, the response will be encrypted using JWE and the configured content encryption + // algorithm (introspection_encrypted_response_enc). The default, if omitted, is that no encryption is performed. + // If both signing and encryption are requested, the response will be signed then encrypted, with the result being + // a Nested JWT, as defined in JWT [RFC7519]. + GetIntrospectionEncryptedResponseAlg() (alg string) + + // GetIntrospectionEncryptedResponseEnc is equivalent to the 'introspection_encrypted_response_enc' client metadata + // value which determines the JWE [RFC7516] algorithm (enc value) as defined in JWA [RFC7518] for content + // encryption of introspection responses. The default, if omitted, is A128CBC-HS256. Note: This parameter MUST NOT + // be specified without setting introspection_encrypted_response_alg. + GetIntrospectionEncryptedResponseEnc() (enc string) + + BaseClient +} + +type decoratedIntrospectionClient struct { + IntrospectionClient +} + +func (r *decoratedIntrospectionClient) GetSignatureKeyID() (kid string) { + return r.GetIntrospectionSignedResponseKeyID() +} + +func (r *decoratedIntrospectionClient) GetSignatureAlg() (alg string) { + return r.GetIntrospectionSignedResponseAlg() +} + +func (r *decoratedIntrospectionClient) GetEncryptionKeyID() (kid string) { + return r.GetIntrospectionEncryptedResponseKeyID() +} + +func (r *decoratedIntrospectionClient) GetEncryptionAlg() (alg string) { + return r.GetIntrospectionEncryptedResponseAlg() +} + +func (r *decoratedIntrospectionClient) GetEncryptionEnc() (enc string) { + return r.GetIntrospectionEncryptedResponseEnc() +} + +func (r *decoratedIntrospectionClient) IsClientSigned() (is bool) { + return false +} diff --git a/token/jwt/client_test.go b/token/jwt/client_test.go new file mode 100644 index 00000000..4ae0be76 --- /dev/null +++ b/token/jwt/client_test.go @@ -0,0 +1,56 @@ +package jwt + +import ( + "fmt" + + "github.com/go-jose/go-jose/v4" +) + +type testClient struct { + secret []byte + kid, alg string + encKID, encAlg, enc string + csigned bool + jwks *jose.JSONWebKeySet + jwksURI string +} + +func (r *testClient) GetClientSecretPlainText() (secret []byte, err error) { + if r.secret != nil { + return r.secret, nil + } + + return nil, fmt.Errorf("not supported") +} + +func (r *testClient) GetSignatureKeyID() (kid string) { + return r.kid +} + +func (r *testClient) GetSignatureAlg() (alg string) { + return r.alg +} + +func (r *testClient) GetEncryptionKeyID() (kid string) { + return r.encKID +} + +func (r *testClient) GetEncryptionAlg() (alg string) { + return r.encAlg +} + +func (r *testClient) GetEncryptionEnc() (enc string) { + return r.enc +} + +func (r *testClient) IsClientSigned() (is bool) { + return r.csigned +} + +func (r *testClient) GetJSONWebKeys() (jwks *jose.JSONWebKeySet) { + return r.jwks +} + +func (r *testClient) GetJSONWebKeysURI() (uri string) { + return r.jwksURI +} diff --git a/token/jwt/consts.go b/token/jwt/consts.go new file mode 100644 index 00000000..19fb5470 --- /dev/null +++ b/token/jwt/consts.go @@ -0,0 +1,34 @@ +package jwt + +import ( + "github.com/go-jose/go-jose/v4" + + "authelia.com/provider/oauth2/internal/consts" +) + +const ( + SigningMethodNone = jose.SignatureAlgorithm(consts.JSONWebTokenAlgNone) + + // UnsafeAllowNoneSignatureType is unsafe to use and should be use to correctly sign and verify alg:none JWT tokens. + UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed" +) + +type unsafeNoneMagicConstant string + +// Keyfunc is used by parsing methods to supply the key for verification. The function receives the parsed, but +// unverified Token. This allows you to use properties in the Header of the token (such as `kid`) to identify which key +// to use. +type Keyfunc func(token *Token) (key any, err error) + +var ( + // SignatureAlgorithmsNone contain all algorithms including 'none'. + SignatureAlgorithmsNone = []jose.SignatureAlgorithm{consts.JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512} + + // SignatureAlgorithms contain all algorithms excluding 'none'. + SignatureAlgorithms = []jose.SignatureAlgorithm{jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512} + + // EncryptionKeyAlgorithms contains all valid JWE's for OAuth 2.0 and OpenID Connect 1.0. + EncryptionKeyAlgorithms = []jose.KeyAlgorithm{jose.RSA1_5, jose.RSA_OAEP, jose.RSA_OAEP_256, jose.A128KW, jose.A192KW, jose.A256KW, jose.DIRECT, jose.ECDH_ES, jose.ECDH_ES_A128KW, jose.ECDH_ES_A192KW, jose.ECDH_ES_A256KW, jose.A128GCMKW, jose.A192GCMKW, jose.A256GCMKW, jose.PBES2_HS256_A128KW, jose.PBES2_HS384_A192KW, jose.PBES2_HS512_A256KW} + + ContentEncryptionAlgorithms = []jose.ContentEncryption{jose.A128CBC_HS256, jose.A192CBC_HS384, jose.A256CBC_HS512, jose.A128GCM, jose.A192GCM, jose.A256GCM} +) diff --git a/token/jwt/errors.go b/token/jwt/errors.go new file mode 100644 index 00000000..6607be86 --- /dev/null +++ b/token/jwt/errors.go @@ -0,0 +1,7 @@ +package jwt + +import "errors" + +var ( + ErrNotRegistered = errors.New("error: no JWKS registered") +) diff --git a/token/jwt/issuer.go b/token/jwt/issuer.go new file mode 100644 index 00000000..da07d599 --- /dev/null +++ b/token/jwt/issuer.go @@ -0,0 +1,115 @@ +package jwt + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "fmt" + + "github.com/go-jose/go-jose/v4" + + "authelia.com/provider/oauth2/internal/consts" +) + +func NewDefaultIssuer(keys ...jose.JSONWebKey) (issuer *DefaultIssuer, err error) { + jwks := &jose.JSONWebKeySet{ + Keys: make([]jose.JSONWebKey, len(keys)), + } + + hasRS256 := false + + for i, key := range keys { + jwks.Keys[i] = key + + if key.Use != consts.JSONWebTokenUseSignature { + continue + } + + if key.Algorithm != string(jose.RS256) { + continue + } + + hasRS256 = true + } + + if !hasRS256 { + return nil, errors.New("no RS256 signature algorithm found") + } + + return issuer, nil +} + +func MustNewDefaultIssuerRS256(key any) (issuer *DefaultIssuer) { + var err error + + if issuer, err = NewDefaultIssuerRS256(key); err != nil { + panic(err) + } + + return issuer +} + +func NewDefaultIssuerRS256(key any) (issuer *DefaultIssuer, err error) { + switch k := key.(type) { + case *rsa.PrivateKey: + if n := k.Size(); n < 256 { + return nil, fmt.Errorf("key must be an *rsa.PrivateKey with at least 2048 bits but got %d", n*8) + } + + return NewDefaultIssuerRS256Unverified(key), nil + default: + return nil, fmt.Errorf("key must be an *rsa.PrivateKey but got %T", k) + } +} + +func NewDefaultIssuerRS256Unverified(key any) (issuer *DefaultIssuer) { + return &DefaultIssuer{ + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + Key: key, + KeyID: "default", + Algorithm: string(jose.RS256), + Use: consts.JSONWebTokenUseSignature, + }, + }, + }, + } +} + +func GenDefaultIssuer() (issuer *DefaultIssuer, err error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + return NewDefaultIssuerRS256(key) +} + +func MustGenDefaultIssuer() (issuer *DefaultIssuer) { + var err error + + if issuer, err = GenDefaultIssuer(); err != nil { + panic(err) + } + + return issuer +} + +type DefaultIssuer struct { + jwks *jose.JSONWebKeySet +} + +func (i *DefaultIssuer) GetIssuerJWK(ctx context.Context, kid, alg, use string) (jwk *jose.JSONWebKey, err error) { + return SearchJWKS(i.jwks, kid, alg, use, false) +} + +func (i *DefaultIssuer) GetIssuerStrictJWK(ctx context.Context, kid, alg, use string) (jwk *jose.JSONWebKey, err error) { + return SearchJWKS(i.jwks, kid, alg, use, true) +} + +type Issuer interface { + GetIssuerJWK(ctx context.Context, kid, alg, use string) (jwk *jose.JSONWebKey, err error) + GetIssuerStrictJWK(ctx context.Context, kid, alg, use string) (jwk *jose.JSONWebKey, err error) +} diff --git a/token/jwt/jwt.go b/token/jwt/jwt.go deleted file mode 100644 index 4b325415..00000000 --- a/token/jwt/jwt.go +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright © 2023 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -// Package jwt is able to generate and validate json web tokens. -// Follows https://datatracker.ietf.org/doc/html/rfc7519 - -package jwt - -import ( - "context" - "crypto" - "crypto/ecdsa" - "crypto/rsa" - "crypto/sha256" - "strings" - - "github.com/go-jose/go-jose/v4" - "github.com/pkg/errors" - - "authelia.com/provider/oauth2/x/errorsx" -) - -type Signer interface { - Generate(ctx context.Context, claims MapClaims, header Mapper) (tokenString string, signature string, err error) - Validate(ctx context.Context, tokenString string) (signature string, err error) - Hash(ctx context.Context, in []byte) ([]byte, error) - Decode(ctx context.Context, tokenString string) (token *Token, err error) - GetSignature(ctx context.Context, token string) (signature string, err error) - GetSigningMethodLength(ctx context.Context) (length int) -} - -var SHA256HashSize = crypto.SHA256.Size() - -type GetPrivateKeyFunc func(ctx context.Context) (any, error) - -// DefaultSigner is responsible for generating and validating JWT challenges -type DefaultSigner struct { - GetPrivateKey GetPrivateKeyFunc -} - -// Generate generates a new authorize code or returns an error. set secret -func (j *DefaultSigner) Generate(ctx context.Context, claims MapClaims, header Mapper) (string, string, error) { - key, err := j.GetPrivateKey(ctx) - if err != nil { - return "", "", err - } - - switch t := key.(type) { - case *jose.JSONWebKey: - return generateToken(claims, header, jose.SignatureAlgorithm(t.Algorithm), t.Key) - case jose.JSONWebKey: - return generateToken(claims, header, jose.SignatureAlgorithm(t.Algorithm), t.Key) - case *rsa.PrivateKey: - return generateToken(claims, header, jose.RS256, t) - case *ecdsa.PrivateKey: - return generateToken(claims, header, jose.ES256, t) - case jose.OpaqueSigner: - switch tt := t.Public().Key.(type) { - case *rsa.PrivateKey: - alg := jose.RS256 - if len(t.Algs()) > 0 { - alg = t.Algs()[0] - } - - return generateToken(claims, header, alg, t) - case *ecdsa.PrivateKey: - alg := jose.ES256 - if len(t.Algs()) > 0 { - alg = t.Algs()[0] - } - - return generateToken(claims, header, alg, t) - default: - return "", "", errors.Errorf("unsupported private / public key pairs: %T, %T", t, tt) - } - default: - return "", "", errors.Errorf("unsupported private key type: %T", t) - } -} - -// Validate validates a token and returns its signature or an error if the token is not valid. -func (j *DefaultSigner) Validate(ctx context.Context, token string) (string, error) { - key, err := j.GetPrivateKey(ctx) - if err != nil { - return "", err - } - - if t, ok := key.(*jose.JSONWebKey); ok { - key = t.Key - } - - switch t := key.(type) { - case *rsa.PrivateKey: - return validateToken(token, t.PublicKey) - case *ecdsa.PrivateKey: - return validateToken(token, t.PublicKey) - case jose.OpaqueSigner: - return validateToken(token, t.Public().Key) - default: - return "", errors.New("Unable to validate token. Invalid PrivateKey type") - } -} - -// Decode will decode a JWT token -func (j *DefaultSigner) Decode(ctx context.Context, token string) (*Token, error) { - key, err := j.GetPrivateKey(ctx) - if err != nil { - return nil, err - } - - if t, ok := key.(*jose.JSONWebKey); ok { - key = t.Key - } - - switch t := key.(type) { - case *rsa.PrivateKey: - return decodeToken(token, t.PublicKey) - case *ecdsa.PrivateKey: - return decodeToken(token, t.PublicKey) - case jose.OpaqueSigner: - return decodeToken(token, t.Public().Key) - default: - return nil, errors.New("Unable to decode token. Invalid PrivateKey type") - } -} - -// GetSignature will return the signature of a token -func (j *DefaultSigner) GetSignature(ctx context.Context, token string) (string, error) { - return getTokenSignature(token) -} - -// Hash will return a given hash based on the byte input or an error upon fail -func (j *DefaultSigner) Hash(ctx context.Context, in []byte) ([]byte, error) { - return hashSHA256(in) -} - -// GetSigningMethodLength will return the length of the signing method -func (j *DefaultSigner) GetSigningMethodLength(ctx context.Context) int { - return SHA256HashSize -} - -func generateToken(claims MapClaims, header Mapper, signingMethod jose.SignatureAlgorithm, privateKey any) (rawToken string, sig string, err error) { - if header == nil || claims == nil { - err = errors.New("either claims or header is nil") - return - } - - token := NewWithClaims(signingMethod, claims) - token.Header = assign(token.Header, header.ToMap()) - - rawToken, err = token.SignedString(privateKey) - if err != nil { - return - } - - sig, err = getTokenSignature(rawToken) - return -} - -func decodeToken(token string, verificationKey any) (*Token, error) { - keyFunc := func(*Token) (any, error) { return verificationKey, nil } - return ParseWithClaims(token, MapClaims{}, keyFunc) -} - -func validateToken(tokenStr string, verificationKey any) (string, error) { - _, err := decodeToken(tokenStr, verificationKey) - if err != nil { - return "", err - } - return getTokenSignature(tokenStr) -} - -func getTokenSignature(token string) (string, error) { - split := strings.Split(token, ".") - if len(split) != 3 { - return "", errors.New("header, body and signature must all be set") - } - return split[2], nil -} - -func hashSHA256(in []byte) ([]byte, error) { - hash := sha256.New() - _, err := hash.Write(in) - if err != nil { - return []byte{}, errorsx.WithStack(err) - } - return hash.Sum([]byte{}), nil -} - -func assign(a, b map[string]any) map[string]any { - for k, w := range b { - if _, ok := a[k]; ok { - continue - } - a[k] = w - } - return a -} diff --git a/token/jwt/jwt_signer_test.go b/token/jwt/jwt_signer_test.go new file mode 100644 index 00000000..f1cf090f --- /dev/null +++ b/token/jwt/jwt_signer_test.go @@ -0,0 +1,320 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package jwt + +/* + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/json" + "fmt" + "strings" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "authelia.com/provider/oauth2/internal/gen" +) + + +var header = &Headers{ + Extra: map[string]any{ + "foo": "bar", + }, +} + +func TestEncrypt(t *testing.T) { + i, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + + c, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + + issuer := jose.JSONWebKey{ + Key: i, + KeyID: "iss-abc123-es512", + Algorithm: string(jose.ES512), + Use: "sig", + } + + clientP := jose.JSONWebKey{ + Key: c, + KeyID: "client-abc123-es512", + Algorithm: string(jose.ECDH_ES_A256KW), + Use: "enc", + } + + client := jose.JSONWebKey{ + Key: &c.PublicKey, + KeyID: "client-abc123-es512", + Algorithm: string(jose.ECDH_ES_A256KW), + Use: "enc", + } + + issuerPublic := jose.JSONWebKey{ + Key: &i.PublicKey, + KeyID: "iss-abc123-es512", + Algorithm: string(jose.ES512), + Use: "sig", + } + + key := make([]byte, 64) + + _, err = rand.Read(key) + require.NoError(t, err) + + issuerDirect := jose.JSONWebKey{ + Key: key, + KeyID: "iss-abc123-es512", + Algorithm: string(jose.DIRECT), + Use: "enc", + } + + data, err := json.Marshal(issuer) + require.NoError(t, err) + fmt.Println(string(data)) + + data, err = json.Marshal(issuer.Public()) + require.NoError(t, err) + fmt.Println(string(data)) + + data, err = json.Marshal(issuerPublic) + require.NoError(t, err) + fmt.Println(string(data)) + + data, err = json.Marshal(issuerPublic.Public()) + require.NoError(t, err) + fmt.Println(string(data)) + + data, err = json.Marshal(client) + require.NoError(t, err) + fmt.Println(string(data)) + + data, err = json.Marshal(clientP) + require.NoError(t, err) + fmt.Println(string(data)) + + data, err = json.Marshal(issuerDirect) + require.NoError(t, err) + fmt.Println(string(data)) + + jwk2 := New() + jwk := New() + + claims := MapClaims{ + "name": "example", + } + + jwsHeaders := &Headers{} + jweHeaders := &Headers{} + + jwk.SetJWS(jwsHeaders, claims, jose.SignatureAlgorithm(issuer.Algorithm)) + jwk2.SetJWS(jwsHeaders, claims, jose.ES256) + jwk.SetJWE(jweHeaders, jose.KeyAlgorithm(client.Algorithm), jose.A256GCM, jose.NONE) + + token, signature, err := jwk.CompactEncrypted(&issuer, &client) + require.NoError(t, err) + + fmt.Println(token) + fmt.Println(signature) + + token, signature, err = jwk2.CompactSigned(&issuer) + require.NoError(t, err) + + fmt.Println(token) + fmt.Println(signature) +} + +func TestHash(t *testing.T) { + for k, tc := range []struct { + d string + strategy Signer + }{ + { + d: "RS256", + strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { + return gen.MustRSAKey(), nil + }}, + }, + { + d: "ES256", + strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { + return gen.MustES256Key(), nil + }}, + }, + } { + t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { + in := []byte("foo") + out, err := tc.strategy.Hash(context.TODO(), in) + assert.NoError(t, err) + assert.NotEqual(t, in, out) + }) + } +} + +func TestAssign(t *testing.T) { + for k, c := range [][]map[string]any{ + { + {"foo": "bar"}, + {"baz": "bar"}, + {"foo": "bar", "baz": "bar"}, + }, + { + {"foo": "bar"}, + {"foo": "baz"}, + {"foo": "bar"}, + }, + { + {}, + {"foo": "baz"}, + {"foo": "baz"}, + }, + { + {"foo": "bar"}, + {"foo": "baz", "bar": "baz"}, + {"foo": "bar", "bar": "baz"}, + }, + } { + assert.EqualValues(t, c[2], assign(c[0], c[1]), "Case %d", k) + } +} + +func TestGenerateJWT(t *testing.T) { + testCases := []struct { + name string + key func() any + }{ + { + name: "DefaultSigner", + key: func() any { + return gen.MustRSAKey() + }, + }, + { + name: "ES256JWTStrategy", + key: func() any { + return gen.MustES256Key() + }, + }, + { + name: "ES256JWTStrategyWithJSONWebKey", + key: func() any { + return &jose.JSONWebKey{ + Key: gen.MustES521Key(), + Algorithm: "ES512", + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + key := tc.key() + + strategy := &DefaultSigner{ + GetPrivateKey: func(_ context.Context) (any, error) { + return key, nil + }, + } + + claims := &JWTClaims{ + ExpiresAt: time.Now().UTC().Add(time.Hour), + } + + token, sig, err := strategy.Generate(ctx, claims.ToMapClaims(), header) + require.NoError(t, err) + require.NotNil(t, token) + assert.NotEmpty(t, sig) + + sig, err = strategy.Validate(ctx, token) + require.NoError(t, err) + assert.NotEmpty(t, sig) + + sig, err = strategy.Validate(ctx, token+"."+"0123456789") + require.Error(t, err) + assert.Empty(t, sig) + + partToken := strings.Split(token, ".")[2] + + sig, err = strategy.Validate(ctx, partToken) + require.Error(t, err) + assert.Empty(t, sig) + + key = tc.key() + + claims = &JWTClaims{ + ExpiresAt: time.Now().UTC().Add(-time.Hour), + } + + token, sig, err = strategy.Generate(ctx, claims.ToMapClaims(), header) + require.NoError(t, err) + require.NotNil(t, token) + assert.NotEmpty(t, sig) + + sig, err = strategy.Validate(ctx, token) + require.Error(t, err) + require.Empty(t, sig) + + claims = &JWTClaims{ + NotBefore: time.Now().UTC().Add(time.Hour), + } + + token, sig, err = strategy.Generate(ctx, claims.ToMapClaims(), header) + require.NoError(t, err) + require.NotNil(t, token) + assert.NotEmpty(t, sig) + + sig, err = strategy.Validate(ctx, token) + require.Error(t, err) + require.Empty(t, sig, "%s", err) + }) + } +} + +func TestValidateSignatureRejectsJWT(t *testing.T) { + for k, tc := range []struct { + d string + strategy Signer + }{ + { + d: "RS256", + strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { + return gen.MustRSAKey(), nil + }, + }, + }, + { + d: "ES256", + strategy: &DefaultSigner{ + GetPrivateKey: func(_ context.Context) (any, error) { + return gen.MustES256Key(), nil + }, + }, + }, + } { + t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { + for k, c := range []string{ + "", + " ", + "foo.bar", + "foo.", + ".foo", + } { + _, err := tc.strategy.Validate(context.TODO(), c) + assert.Error(t, err) + t.Logf("Passed test case %d", k) + } + }) + } +} + +*/ diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go new file mode 100644 index 00000000..b69d9692 --- /dev/null +++ b/token/jwt/jwt_strategy.go @@ -0,0 +1,284 @@ +package jwt + +import ( + "context" + "fmt" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + + "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/x/errorsx" +) + +// Strategy represents the strategy for encoding and decoding JWT's. +type Strategy interface { + Encode(ctx context.Context, opts ...StrategyOpt) (tokenString string, signature string, err error) + Decode(ctx context.Context, tokenString string, opts ...StrategyOpt) (token *Token, err error) + Decrypt(ctx context.Context, tokenStringEnc string, opts ...StrategyOpt) (tokenString, signature string, jwe *jose.JSONWebEncryption, err error) +} + +type StrategyConfig interface { + // GetJWKSFetcherStrategy returns the JWKS fetcher strategy. + GetJWKSFetcherStrategy(ctx context.Context) (strategy JWKSFetcherStrategy) +} + +type JWKSFetcherStrategy interface { + // Resolve returns the JSON Web Key Set, or an error if something went wrong. The forceRefresh, if true, forces + // the strategy to fetch the key from the remote. If forceRefresh is false, the strategy may use a caching strategy + // to fetch the key. + Resolve(ctx context.Context, location string, ignoreCache bool) (jwks *jose.JSONWebKeySet, err error) +} + +// DefaultStrategy is responsible for providing JWK encoding and cryptographic functionality. +type DefaultStrategy struct { + Config StrategyConfig + Issuer Issuer +} + +func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (tokenString string, signature string, err error) { + o := &optsStrategy{ + claims: MapClaims{}, + headers: NewHeaders(), + } + + for _, opt := range opts { + if err = opt(o); err != nil { + return "", "", err + } + } + + var ( + keySig *jose.JSONWebKey + ) + + if o.client == nil { + if keySig, err = j.Issuer.GetIssuerJWK(ctx, "", string(jose.RS256), consts.JSONWebTokenUseSignature); err != nil { + return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving issuer jwk: %w", err)) + } + } else if keySig, err = j.Issuer.GetIssuerJWK(ctx, o.client.GetSignatureKeyID(), o.client.GetSignatureAlg(), consts.JSONWebTokenUseSignature); err != nil { + return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving issuer jwk: %w", err)) + } + + if o.client == nil { + return encodeCompactSigned(ctx, o.claims, o.headers, keySig) + } + + kid, alg, enc := o.client.GetEncryptionKeyID(), o.client.GetEncryptionAlg(), o.client.GetEncryptionEnc() + + if len(kid) == 0 && len(alg) == 0 { + return encodeCompactSigned(ctx, o.claims, o.headers, keySig) + } + + if len(enc) == 0 { + enc = string(jose.A128CBC_HS256) + } + + var keyEnc *jose.JSONWebKey + + if IsEncryptedJWTClientSecretAlg(alg) { + if keyEnc, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving issuer jwk: error occurred retrieving the client secret: %w", err)) + } + } else if keyEnc, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseEncryption, false); err != nil { + return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving client jwk: %w", err)) + } + + return encodeNestedCompactEncrypted(ctx, o.claims, o.headers, o.headersJWE, keySig, keyEnc, jose.ContentEncryption(enc)) +} + +func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts ...StrategyOpt) (token *Token, err error) { + o := &optsStrategy{ + sigAlgorithm: SignatureAlgorithms, + keyAlgorithm: EncryptionKeyAlgorithms, + contentEncryption: ContentEncryptionAlgorithms, + jwsKeyFunc: nil, + jweKeyFunc: nil, + } + + for _, opt := range opts { + if err = opt(o); err != nil { + return nil, errorsx.WithStack(err) + } + } + + var ( + key *jose.JSONWebKey + t *jwt.JSONWebToken + jwe *jose.JSONWebEncryption + ) + + if IsEncryptedJWT(tokenString) { + if jwe, err = jose.ParseEncryptedCompact(tokenString, o.keyAlgorithm, o.contentEncryption); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + var ( + kid, alg, cty string + ) + + if kid, alg, _, cty, err = headerValidateJWE(jwe.Header); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + if o.jweKeyFunc != nil { + if key, err = o.jweKeyFunc(ctx, jwe, kid, alg); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if IsEncryptedJWTClientSecretAlg(alg) { + if o.client == nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + + if key, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + + var rawJWT []byte + + if rawJWT, err = jwe.Decrypt(key); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + if t, err = jwt.ParseSigned(string(rawJWT), o.sigAlgorithm); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + if err = headerValidateJWSNested(t.Headers, cty); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + } else if t, err = jwt.ParseSigned(tokenString, o.sigAlgorithm); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + claims := MapClaims{} + + if err = t.UnsafeClaimsWithoutVerification(&claims); err != nil { + return nil, &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} + } + + var kid, alg string + + if kid, alg, err = headerValidateJWS(t.Headers); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + if o.jwsKeyFunc != nil { + if key, err = o.jwsKeyFunc(ctx, t, claims); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if o.client != nil && o.client.IsClientSigned() { + if ckid := o.client.GetSignatureKeyID(); ckid != "" && ckid != kid { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: fmt.Errorf("error validating the jws header: kid '%s' does not match the registered kid '%s'", kid, ckid)}) + } + + if calg := o.client.GetSignatureAlg(); calg != "" && calg != alg { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not match the registered alg '%s'", alg, calg)}) + } + + if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseSignature, true); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseSignature); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + + if err = t.Claims(key.Public(), &claims); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) + } + + if token, err = newToken(t, claims); err != nil { + return nil, errorsx.WithStack(err) + } + + token.AssignJWE(jwe) + + if err = claims.Valid(); err != nil { + return token, errorsx.WithStack(err) + } + + token.valid = true + + return token, nil +} + +func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, opts ...StrategyOpt) (tokenString, signature string, jwe *jose.JSONWebEncryption, err error) { + if !IsEncryptedJWT(tokenStringEnc) { + if IsSignedJWT(tokenStringEnc) { + return tokenStringEnc, "", nil, nil + } else { + return tokenStringEnc, "", nil, fmt.Errorf("token does not appear to be a jwe or jws compact serializd jwt") + } + } + + o := &optsStrategy{ + sigAlgorithm: SignatureAlgorithmsNone, + keyAlgorithm: EncryptionKeyAlgorithms, + contentEncryption: ContentEncryptionAlgorithms, + } + + for _, opt := range opts { + if err = opt(o); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } + + var ( + key *jose.JSONWebKey + ) + + if jwe, err = jose.ParseEncryptedCompact(tokenStringEnc, o.keyAlgorithm, o.contentEncryption); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + var ( + kid, alg, cty string + ) + + if kid, alg, _, cty, err = headerValidateJWE(jwe.Header); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + if o.jweKeyFunc != nil { + if key, err = o.jweKeyFunc(ctx, jwe, kid, alg); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if IsEncryptedJWTClientSecretAlg(alg) { + if o.client == nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + + if key, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + + var tokenRaw []byte + + if tokenRaw, err = jwe.Decrypt(key); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + tokenString = string(tokenRaw) + + var t *jwt.JSONWebToken + + if t, err = jwt.ParseSigned(tokenString, SignatureAlgorithmsNone); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + if err = headerValidateJWSNested(t.Headers, cty); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + if signature, err = getJWTSignature(tokenString); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + return string(tokenRaw), signature, jwe, nil +} diff --git a/token/jwt/jwt_strategy_opts.go b/token/jwt/jwt_strategy_opts.go new file mode 100644 index 00000000..6291bb50 --- /dev/null +++ b/token/jwt/jwt_strategy_opts.go @@ -0,0 +1,179 @@ +package jwt + +import ( + "context" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" +) + +type optsStrategy struct { + client Client + claims MapClaims + + headers, headersJWE Mapper + + sigAlgorithm []jose.SignatureAlgorithm + keyAlgorithm []jose.KeyAlgorithm + contentEncryption []jose.ContentEncryption + + jwsKeyFunc KeyFuncJWS + jweKeyFunc KeyFuncJWE +} + +type ( + KeyFuncJWS func(ctx context.Context, token *jwt.JSONWebToken, claims MapClaims) (jwk *jose.JSONWebKey, err error) + KeyFuncJWE func(ctx context.Context, jwe *jose.JSONWebEncryption, kid, alg string) (jwk *jose.JSONWebKey, err error) + StrategyOpt func(opts *optsStrategy) (err error) +) + +func WithHeaders(headers Mapper) StrategyOpt { + return func(opts *optsStrategy) (err error) { + opts.headers = headers + + return nil + } +} + +func WithHeadersJWE(headers Mapper) StrategyOpt { + return func(opts *optsStrategy) (err error) { + opts.headersJWE = headers + + return nil + } +} + +func WithClaims(claims MapClaims) StrategyOpt { + return func(opts *optsStrategy) (err error) { + opts.claims = claims + + return nil + } +} + +func WithClient(client Client) StrategyOpt { + return func(opts *optsStrategy) (err error) { + opts.client = client + + return nil + } +} + +func WithIDTokenClient(client any) StrategyOpt { + return func(opts *optsStrategy) (err error) { + switch c := client.(type) { + case IDTokenClient: + opts.client = &decoratedIDTokenClient{IDTokenClient: c} + } + + return nil + } +} + +func WithUserInfoClient(client any) StrategyOpt { + return func(opts *optsStrategy) (err error) { + switch c := client.(type) { + case UserInfoClient: + opts.client = &decoratedUserInfoClient{UserInfoClient: c} + } + + return nil + } +} + +func WithIntrospectionClient(client any) StrategyOpt { + return func(opts *optsStrategy) (err error) { + switch c := client.(type) { + case IntrospectionClient: + opts.client = &decoratedIntrospectionClient{IntrospectionClient: c} + } + + return nil + } +} + +func WithJARMClient(client any) StrategyOpt { + return func(opts *optsStrategy) (err error) { + switch c := client.(type) { + case JARMClient: + opts.client = &decoratedJARMClient{JARMClient: c} + } + + return nil + } +} + +func WithJARClient(client any) StrategyOpt { + return func(opts *optsStrategy) (err error) { + switch c := client.(type) { + case JARClient: + opts.client = &decoratedJARClient{JARClient: c} + } + + return nil + } +} + +func WithJWTProfileAccessTokenClient(client any) StrategyOpt { + return func(opts *optsStrategy) (err error) { + switch c := client.(type) { + case JWTProfileAccessTokenClient: + opts.client = &decoratedJWTProfileAccessTokenClient{JWTProfileAccessTokenClient: c} + } + + return nil + } +} + +func WithNewStatelessJWTProfileIntrospectionClient(client any) StrategyOpt { + return func(opts *optsStrategy) (err error) { + switch c := client.(type) { + case IntrospectionClient: + opts.client = &decoratedIntrospectionClient{IntrospectionClient: c} + case JWTProfileAccessTokenClient: + opts.client = &decoratedJWTProfileAccessTokenClient{JWTProfileAccessTokenClient: c} + } + + return nil + } +} + +func WithSigAlgorithm(algs ...jose.SignatureAlgorithm) StrategyOpt { + return func(opts *optsStrategy) (err error) { + opts.sigAlgorithm = algs + + return nil + } +} + +func WithKeyAlgorithm(algs ...jose.KeyAlgorithm) StrategyOpt { + return func(opts *optsStrategy) (err error) { + opts.keyAlgorithm = algs + + return nil + } +} + +func WithContentEncryption(enc ...jose.ContentEncryption) StrategyOpt { + return func(opts *optsStrategy) (err error) { + opts.contentEncryption = enc + + return nil + } +} + +func WithKeyFunc(f KeyFuncJWS) StrategyOpt { + return func(opts *optsStrategy) (err error) { + opts.jwsKeyFunc = f + + return nil + } +} + +func WithKeyFuncJWE(f KeyFuncJWE) StrategyOpt { + return func(opts *optsStrategy) (err error) { + opts.jweKeyFunc = f + + return nil + } +} diff --git a/token/jwt/jwt_strategy_test.go b/token/jwt/jwt_strategy_test.go new file mode 100644 index 00000000..65c37418 --- /dev/null +++ b/token/jwt/jwt_strategy_test.go @@ -0,0 +1,361 @@ +package jwt + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/require" + + "authelia.com/provider/oauth2/internal/consts" +) + +func TestDefaultStrategy(t *testing.T) { + ctx := context.TODO() + + config := &testConfig{} + + issuerRS256, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + issuerES512, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + + issuerES512enc, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + + clientES512, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + + clientES512enc, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + + issuerJWKS := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "rs256-sig", + Key: issuerRS256, + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.RS256), + }, + { + KeyID: "es512-sig", + Key: issuerES512, + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: issuerES512enc, + Use: consts.JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + issuerClientJWKS := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "rs256-sig", + Key: &issuerRS256.PublicKey, + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.RS256), + }, + { + KeyID: "es512-sig", + Key: &issuerES512.PublicKey, + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: &issuerES512enc.PublicKey, + Use: consts.JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + issuer := &DefaultIssuer{ + jwks: issuerJWKS, + } + + clientIssuerJWKS := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "es512-sig", + Key: clientES512, + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: clientES512enc, + Use: consts.JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + clientJWKS := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "es512-sig", + Key: &clientES512.PublicKey, + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: &clientES512enc.PublicKey, + Use: consts.JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + issuerJWKSenc := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "es512-sig", + Key: &issuerES512.PublicKey, + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: &issuerES512enc.PublicKey, + Use: consts.JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + clientJWKSenc := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "es512-sig", + Key: &clientES512.PublicKey, + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: &clientES512enc.PublicKey, + Use: consts.JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + client := &testClient{ + kid: "es512-sig", + alg: "ES512", + encKID: "", + encAlg: "", + enc: "", + csigned: false, + jwks: clientJWKS, + jwksURI: "", + } + + clientEnc := &testClient{ + kid: "es512-sig", + alg: "ES512", + encKID: "es512-enc", + encAlg: string(jose.ECDH_ES_A256KW), + enc: string(jose.A256GCM), + csigned: false, + jwks: clientJWKSenc, + jwksURI: "", + } + + key128 := make([]byte, 32) + + _, err = rand.Read(key128) + require.NoError(t, err) + + clientEncAsymmetric := &testClient{ + kid: "es512-sig", + alg: "ES512", + encKID: "", + encAlg: string(jose.PBES2_HS256_A128KW), + enc: string(jose.A256GCM), + csigned: true, + secret: key128, + jwks: issuerJWKSenc, + jwksURI: "", + } + + strategy := &DefaultStrategy{ + Config: config, + Issuer: issuer, + } + + claims := MapClaims{ + "value": 1, + } + + headers1 := &Headers{ + Extra: map[string]any{ + consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeAccessToken, + }, + } + + var headersEnc *Headers + + var ( + token1, signature1 string + ) + + token1, signature1, err = strategy.Encode(ctx, WithClaims(claims), WithHeaders(headers1), WithClient(client)) + require.NoError(t, err) + + require.True(t, IsSignedJWT(token1)) + + fmt.Println("---------") + fmt.Println("Token 1:") + fmt.Println("\tValue:", token1) + fmt.Println("\tSignature:", signature1) + fmt.Println("---------") + fmt.Println("") + + headersEnc = &Headers{} + + var ( + token2, signature2 string + ) + + headers2 := &Headers{ + Extra: map[string]any{ + consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeJWT, + }, + } + + token2, signature2, err = strategy.Encode(ctx, WithClaims(claims), WithHeaders(headers2), WithHeadersJWE(headersEnc), WithClient(clientEnc)) + require.NoError(t, err) + require.True(t, IsEncryptedJWT(token2)) + + fmt.Println("---------") + fmt.Println("Token 2:") + fmt.Println("\tValue:", token2) + fmt.Println("\tSignature:", signature2) + fmt.Println("---------") + fmt.Println("") + + var ( + token3, signature3 string + ) + + token3, signature3, err = strategy.Encode(ctx, WithClaims(claims), WithHeaders(headers1), WithHeadersJWE(headersEnc), WithClient(clientEncAsymmetric)) + require.NoError(t, err) + + fmt.Println("---------") + fmt.Println("Token 3:") + fmt.Println("\tValue:", token3) + fmt.Println("\tSignature:", signature3) + fmt.Println("---------") + fmt.Println("") + + clientIssuer := &DefaultIssuer{ + jwks: clientIssuerJWKS, + } + + clientStrategy := &DefaultStrategy{ + Config: config, + Issuer: clientIssuer, + } + + issuerClient := &testClient{ + kid: "es512-sig", + alg: "ES512", + encKID: "", + encAlg: "", + enc: "", + csigned: true, + jwks: issuerClientJWKS, + jwksURI: "", + } + + tokenString, signature, jwe, err := clientStrategy.Decrypt(ctx, token2, WithClient(clientEncAsymmetric)) + require.NoError(t, err) + + fmt.Println("---------") + fmt.Println("Token 2 (Decrypted):") + fmt.Println("\tValue:", tokenString) + fmt.Println("\tSignature:", signature) + fmt.Println("\tJWE:", jwe) + fmt.Println("---------") + fmt.Println("") + + tokenString, signature, jwe, err = clientStrategy.Decrypt(ctx, token3, WithClient(clientEncAsymmetric)) + require.NoError(t, err) + + fmt.Println("---------") + fmt.Println("Token 3 (Decrypted):") + fmt.Println("\tValue:", tokenString) + fmt.Println("\tSignature:", signature) + fmt.Println("\tJWE:", jwe) + fmt.Println("---------") + fmt.Println("") + + tok, err := clientStrategy.Decode(ctx, token1, WithClient(issuerClient)) + require.NoError(t, err) + + fmt.Printf("%v+\n", tok) + + tok, err = clientStrategy.Decode(ctx, token2, WithClient(issuerClient)) + require.NoError(t, err) + + fmt.Printf("%v+\n", tok) + + tok, err = clientStrategy.Decode(ctx, token3, WithClient(clientEncAsymmetric)) + require.NoError(t, err) + + fmt.Printf("%v+\n", tok) +} + +type testConfig struct{} + +func (*testConfig) GetJWKSFetcherStrategy(ctx context.Context) (strategy JWKSFetcherStrategy) { + return &testFetcher{client: http.DefaultClient} +} + +type testFetcher struct { + client *http.Client +} + +func (f *testFetcher) Resolve(ctx context.Context, location string, _ bool) (jwks *jose.JSONWebKeySet, err error) { + var req *http.Request + + if req, err = http.NewRequest(http.MethodGet, location, nil); err != nil { + return nil, err + } + + req.WithContext(ctx) + + var resp *http.Response + + if resp, err = f.client.Do(req); err != nil { + return nil, err + } + + defer resp.Body.Close() + + decoder := json.NewDecoder(resp.Body) + + jwks = &jose.JSONWebKeySet{} + + if err = decoder.Decode(jwks); err != nil { + return nil, err + } + + return jwks, nil +} diff --git a/token/jwt/jwt_test.go b/token/jwt/jwt_test.go deleted file mode 100644 index 42236f31..00000000 --- a/token/jwt/jwt_test.go +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright © 2023 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package jwt - -import ( - "context" - "fmt" - "strings" - "testing" - "time" - - "github.com/go-jose/go-jose/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "authelia.com/provider/oauth2/internal/gen" -) - -var header = &Headers{ - Extra: map[string]any{ - "foo": "bar", - }, -} - -func TestHash(t *testing.T) { - for k, tc := range []struct { - d string - strategy Signer - }{ - { - d: "RS256", - strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil - }}, - }, - { - d: "ES256", - strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustES256Key(), nil - }}, - }, - } { - t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { - in := []byte("foo") - out, err := tc.strategy.Hash(context.TODO(), in) - assert.NoError(t, err) - assert.NotEqual(t, in, out) - }) - } -} - -func TestAssign(t *testing.T) { - for k, c := range [][]map[string]any{ - { - {"foo": "bar"}, - {"baz": "bar"}, - {"foo": "bar", "baz": "bar"}, - }, - { - {"foo": "bar"}, - {"foo": "baz"}, - {"foo": "bar"}, - }, - { - {}, - {"foo": "baz"}, - {"foo": "baz"}, - }, - { - {"foo": "bar"}, - {"foo": "baz", "bar": "baz"}, - {"foo": "bar", "bar": "baz"}, - }, - } { - assert.EqualValues(t, c[2], assign(c[0], c[1]), "Case %d", k) - } -} - -func TestGenerateJWT(t *testing.T) { - var key any = gen.MustRSAKey() - for k, tc := range []struct { - d string - strategy Signer - resetKey func(strategy Signer) - }{ - { - d: "DefaultSigner", - strategy: &DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, - }, - resetKey: func(strategy Signer) { - key = gen.MustRSAKey() - }, - }, - { - d: "ES256JWTStrategy", - strategy: &DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, - }, - resetKey: func(strategy Signer) { - key = &jose.JSONWebKey{ - Key: gen.MustES521Key(), - Algorithm: "ES512", - } - }, - }, - { - d: "ES256JWTStrategy", - strategy: &DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, - }, - resetKey: func(strategy Signer) { - key = gen.MustES256Key() - }, - }, - } { - t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { - claims := &JWTClaims{ - ExpiresAt: time.Now().UTC().Add(time.Hour), - } - - token, sig, err := tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - assert.NotEmpty(t, sig) - - sig, err = tc.strategy.Validate(context.TODO(), token) - require.NoError(t, err) - assert.NotEmpty(t, sig) - - sig, err = tc.strategy.Validate(context.TODO(), token+"."+"0123456789") - require.Error(t, err) - assert.Empty(t, sig) - - partToken := strings.Split(token, ".")[2] - - sig, err = tc.strategy.Validate(context.TODO(), partToken) - require.Error(t, err) - assert.Empty(t, sig) - - tc.resetKey(tc.strategy) - - claims = &JWTClaims{ - ExpiresAt: time.Now().UTC().Add(-time.Hour), - } - - token, sig, err = tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - assert.NotEmpty(t, sig) - - sig, err = tc.strategy.Validate(context.TODO(), token) - require.Error(t, err) - require.Empty(t, sig) - - claims = &JWTClaims{ - NotBefore: time.Now().UTC().Add(time.Hour), - } - - token, sig, err = tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - assert.NotEmpty(t, sig) - - sig, err = tc.strategy.Validate(context.TODO(), token) - require.Error(t, err) - require.Empty(t, sig, "%s", err) - }) - } -} - -func TestValidateSignatureRejectsJWT(t *testing.T) { - for k, tc := range []struct { - d string - strategy Signer - }{ - { - d: "RS256", - strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil - }, - }, - }, - { - d: "ES256", - strategy: &DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustES256Key(), nil - }, - }, - }, - } { - t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { - for k, c := range []string{ - "", - " ", - "foo.bar", - "foo.", - ".foo", - } { - _, err := tc.strategy.Validate(context.TODO(), c) - assert.Error(t, err) - t.Logf("Passed test case %d", k) - } - }) - } -} diff --git a/token/jwt/token.go b/token/jwt/token.go index 150b5da3..d76cbf57 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "reflect" + "strings" "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" @@ -17,6 +18,115 @@ import ( "authelia.com/provider/oauth2/x/errorsx" ) +func New() *Token { + return &Token{ + Header: map[string]any{}, + HeaderJWE: map[string]any{}, + } +} + +// NewWithClaims creates an unverified Token with the given claims and signing method +func NewWithClaims(alg jose.SignatureAlgorithm, claims MapClaims) *Token { + return &Token{ + Claims: claims, + SignatureAlgorithm: alg, + Header: map[string]any{}, + HeaderJWE: map[string]any{}, + } +} + +// Parse is an overload for ParseCustom which accepts all normal algs including 'none'. +func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { + return ParseCustom(tokenString, keyFunc, SignatureAlgorithmsNone...) +} + +// ParseCustom parses, validates, and returns a token. The keyFunc will receive the parsed token and should +// return the key for validating. If everything is kosher, err will be nil. +func ParseCustom(tokenString string, keyFunc Keyfunc, algs ...jose.SignatureAlgorithm) (token *Token, err error) { + return ParseCustomWithClaims(tokenString, MapClaims{}, keyFunc, algs...) +} + +// ParseWithClaims is an overload for ParseCustomWithClaims which accepts all normal algs including 'none'. +func ParseWithClaims(tokenString string, claims MapClaims, keyFunc Keyfunc) (token *Token, err error) { + return ParseCustomWithClaims(tokenString, claims, keyFunc, SignatureAlgorithmsNone...) +} + +// ParseCustomWithClaims parses, validates, and returns a token with its respective claims. The keyFunc will receive the parsed token and should +// return the key for validating. If everything is kosher, err will be nil. +func ParseCustomWithClaims(tokenString string, claims MapClaims, keyFunc Keyfunc, algs ...jose.SignatureAlgorithm) (token *Token, err error) { + var parsed *jwt.JSONWebToken + + if parsed, err = jwt.ParseSigned(tokenString, algs); err != nil { + return &Token{}, &ValidationError{Errors: ValidationErrorMalformed, Inner: err} + } + + // fill unverified claims + // This conversion is required because go-jose supports + // only marshalling structs or maps but not alias types from maps + // + // The KeyFunc(*Token) function requires the claims to be set into the + // Token, that is an unverified token, therefore an UnsafeClaimsWithoutVerification is done first + // then with the returned key, the claims gets verified. + if err = parsed.UnsafeClaimsWithoutVerification(&claims); err != nil { + return nil, &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} + } + + // creates an unsafe token + if token, err = newToken(parsed, claims); err != nil { + return nil, err + } + + if keyFunc == nil { + return token, &ValidationError{Errors: ValidationErrorUnverifiable, text: "no Keyfunc was provided."} + } + + var key any + + if key, err = keyFunc(token); err != nil { + // keyFunc returned an error + var ve *ValidationError + + if errors.As(err, &ve) { + return token, ve + } + + return token, &ValidationError{Errors: ValidationErrorUnverifiable, Inner: err} + } + + if key == nil { + return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: "keyfunc returned a nil verification key"} + } + // To verify signature go-jose requires a pointer to + // public key instead of the public key value. + // The pointer values provides that pointer. + // E.g. transform rsa.PublicKey -> *rsa.PublicKey + key = pointer(key) + + // verify signature with returned key + _, validNoneKey := key.(*unsafeNoneMagicConstant) + isSignedToken := !(token.SignatureAlgorithm == SigningMethodNone && validNoneKey) + if isSignedToken { + if err = parsed.Claims(key, &claims); err != nil { + return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: err.Error()} + } + } + + // Validate claims + // This validation is performed to be backwards compatible + // with jwt-go library behavior + if err = claims.Valid(); err != nil { + if e, ok := err.(*ValidationError); !ok { + err = &ValidationError{Inner: e, Errors: ValidationErrorClaimsInvalid} + } + + return token, err + } + + token.valid = true + + return token, nil +} + // Token represets a JWT Token // This token provide an adaptation to // transit from [jwt-go](https://github.com/dgrijalva/jwt-go) @@ -24,30 +134,19 @@ import ( // It provides method signatures compatible with jwt-go but implemented // using go-json type Token struct { - Header map[string]any // The first segment of the token - Claims MapClaims // The second segment of the token - Method jose.SignatureAlgorithm - valid bool -} - -const ( - SigningMethodNone = jose.SignatureAlgorithm(consts.JSONWebTokenAlgNone) - // This key should be use to correctly sign and verify alg:none JWT tokens - UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed" - - JWTHeaderType = jose.HeaderKey(consts.JSONWebTokenHeaderType) -) + KeyID string + SignatureAlgorithm jose.SignatureAlgorithm // alg (JWS) + KeyAlgorithm jose.KeyAlgorithm // alg (JWE) + ContentEncryption jose.ContentEncryption // enc (JWE) + CompressionAlgorithm jose.CompressionAlgorithm // zip (JWE) -const ( - JWTHeaderKeyValueType = consts.JSONWebTokenHeaderType -) + Header map[string]any + HeaderJWE map[string]any -const ( - JWTHeaderTypeValueJWT = consts.JSONWebTokenTypeJWT - JWTHeaderTypeValueAccessTokenJWT = consts.JSONWebTokenTypeAccessToken -) + Claims MapClaims -type unsafeNoneMagicConstant string + valid bool +} // Valid informs if the token was verified against a given verification key // and claims are valid @@ -64,194 +163,254 @@ type Claims interface { Valid() error } -// NewWithClaims creates an unverified Token with the given claims and signing method -func NewWithClaims(method jose.SignatureAlgorithm, claims MapClaims) *Token { - return &Token{ - Claims: claims, - Method: method, - Header: map[string]any{}, +func (t *Token) toSignedJoseHeader() (header map[jose.HeaderKey]any) { + header = map[jose.HeaderKey]any{ + consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeJWT, } -} -func (t *Token) toJoseHeader() map[jose.HeaderKey]any { - h := map[jose.HeaderKey]any{ - JWTHeaderType: JWTHeaderTypeValueJWT, - } for k, v := range t.Header { - h[jose.HeaderKey(k)] = v + header[jose.HeaderKey(k)] = v } - return h + + return header } -// SignedString provides a compatible `jwt-go` Token.SignedString method -// -// > Get the complete, signed token -func (t *Token) SignedString(k any) (rawToken string, err error) { - if _, ok := k.(unsafeNoneMagicConstant); ok { - rawToken, err = unsignedToken(t) - return +func (t *Token) toEncryptedJoseHeader() (header map[jose.HeaderKey]any) { + header = map[jose.HeaderKey]any{ + consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeJWT, } - var signer jose.Signer - - key := jose.SigningKey{ - Algorithm: t.Method, - Key: k, + if cty, ok := t.Header[consts.JSONWebTokenHeaderType]; ok { + header[consts.JSONWebTokenHeaderContentType] = cty } - opts := &jose.SignerOptions{ExtraHeaders: t.toJoseHeader()} - signer, err = jose.NewSigner(key, opts) - if err != nil { - err = errorsx.WithStack(err) - return + + for k, v := range t.HeaderJWE { + header[jose.HeaderKey(k)] = v } - // A explicit conversion from type alias MapClaims - // to map[string]any is required because the - // go-jose CompactSerialize() only support explicit maps - // as claims or structs but not type aliases from maps. - claims := map[string]any(t.Claims) - rawToken, err = jwt.Signed(signer).Claims(claims).Serialize() - if err != nil { - err = &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} + return header +} + +// SetJWS sets the JWS output values. +func (t *Token) SetJWS(header Mapper, claims MapClaims, alg jose.SignatureAlgorithm) { + assign(t.Header, header.ToMap()) + + t.SignatureAlgorithm = alg + + t.Claims = claims +} + +// SetJWE sets the JWE output values. +func (t *Token) SetJWE(header Mapper, alg jose.KeyAlgorithm, enc jose.ContentEncryption, zip jose.CompressionAlgorithm) { + assign(t.HeaderJWE, header.ToMap()) + + t.KeyAlgorithm = alg + t.ContentEncryption = enc + t.CompressionAlgorithm = zip +} + +// AssignJWE assigns values derived from the JWE decryption process to the Token. +func (t *Token) AssignJWE(jwe *jose.JSONWebEncryption) { + if jwe == nil { return } - return -} -func unsignedToken(t *Token) (string, error) { - t.Header[consts.JSONWebTokenHeaderAlgorithm] = consts.JSONWebTokenAlgNone - if _, ok := t.Header[string(JWTHeaderType)]; !ok { - t.Header[string(JWTHeaderType)] = JWTHeaderTypeValueJWT + t.HeaderJWE = map[string]any{ + consts.JSONWebTokenHeaderAlgorithm: jwe.Header.Algorithm, } - hbytes, err := json.Marshal(&t.Header) - if err != nil { - return "", errorsx.WithStack(err) + + if jwe.Header.KeyID != "" { + t.HeaderJWE[consts.JSONWebTokenHeaderKeyIdentifier] = jwe.Header.KeyID } - bbytes, err := json.Marshal(&t.Claims) - if err != nil { - return "", errorsx.WithStack(err) + + for header, value := range jwe.Header.ExtraHeaders { + h := string(header) + + t.HeaderJWE[h] = value + + switch h { + case consts.JSONWebTokenHeaderEncryptionAlgorithm: + if v, ok := value.(string); ok { + t.ContentEncryption = jose.ContentEncryption(v) + } + case consts.JSONWebTokenHeaderCompressionAlgorithm: + if v, ok := value.(string); ok { + t.CompressionAlgorithm = jose.CompressionAlgorithm(v) + } + } } - h := base64.RawURLEncoding.EncodeToString(hbytes) - b := base64.RawURLEncoding.EncodeToString(bbytes) - return fmt.Sprintf("%v.%v.", h, b), nil + + t.KeyAlgorithm = jose.KeyAlgorithm(jwe.Header.Algorithm) } -func newToken(parsedToken *jwt.JSONWebToken, claims MapClaims) (*Token, error) { - token := &Token{Claims: claims} - if len(parsedToken.Headers) != 1 { - return nil, &ValidationError{text: fmt.Sprintf("only one header supported, got %v", len(parsedToken.Headers)), Errors: ValidationErrorMalformed} +func (t *Token) CompactEncrypted(keySig, keyEnc any) (tokenString, signature string, err error) { + var ( + signed string + ) + + if signed, signature, err = t.CompactSigned(keySig); err != nil { + return "", "", err } - // copy headers - h := parsedToken.Headers[0] - token.Header = map[string]any{ - consts.JSONWebTokenHeaderAlgorithm: h.Algorithm, + rcpt := jose.Recipient{ + Algorithm: t.KeyAlgorithm, + Key: keyEnc, } - if h.KeyID != "" { - token.Header[consts.JSONWebTokenHeaderKeyIdentifier] = h.KeyID + + opts := &jose.EncrypterOptions{ + Compression: t.CompressionAlgorithm, + ExtraHeaders: t.toEncryptedJoseHeader(), } - for k, v := range h.ExtraHeaders { - token.Header[string(k)] = v + + if _, ok := opts.ExtraHeaders[consts.JSONWebTokenHeaderContentType]; !ok { + var typ any + + if typ, ok = t.Header[consts.JSONWebTokenHeaderType]; ok { + opts.ExtraHeaders[consts.JSONWebTokenHeaderContentType] = typ + } else { + opts.ExtraHeaders[consts.JSONWebTokenHeaderContentType] = consts.JSONWebTokenTypeJWT + } } - token.Method = jose.SignatureAlgorithm(h.Algorithm) + var encrypter jose.Encrypter - return token, nil -} + if encrypter, err = jose.NewEncrypter(t.ContentEncryption, rcpt, opts); err != nil { + return "", "", errorsx.WithStack(err) + } -// Keyfunc is used by parsing methods to supply the key for verification. The function receives the parsed, but -// unverified Token. This allows you to use properties in the Header of the token (such as `kid`) to identify which key -// to use. -type Keyfunc func(*Token) (any, error) + var token *jose.JSONWebEncryption -// Parse is an overload for ParseCustom which accepts all normal algs including 'none'. -func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { - return ParseCustom(tokenString, keyFunc, consts.JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512) -} + if token, err = encrypter.Encrypt([]byte(signed)); err != nil { + return "", "", errorsx.WithStack(err) + } -// ParseCustom parses, validates, and returns a token. The keyFunc will receive the parsed token and should -// return the key for validating. If everything is kosher, err will be nil. -func ParseCustom(tokenString string, keyFunc Keyfunc, algs ...jose.SignatureAlgorithm) (*Token, error) { - return ParseCustomWithClaims(tokenString, MapClaims{}, keyFunc, algs...) + if tokenString, err = token.CompactSerialize(); err != nil { + return "", "", errorsx.WithStack(err) + } + + return tokenString, signature, nil } -// ParseWithClaims is an overload for ParseCustomWithClaims which accepts all normal algs including 'none'. -func ParseWithClaims(rawToken string, claims MapClaims, keyFunc Keyfunc) (*Token, error) { - return ParseCustomWithClaims(rawToken, claims, keyFunc, consts.JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512) +func (t *Token) CompactSigned(k any) (tokenString, signature string, err error) { + if tokenString, err = t.CompactSignedString(k); err != nil { + return "", "", err + } + + if signature, err = getJWTSignature(tokenString); err != nil { + return "", "", err + } + + return tokenString, signature, nil } -// ParseCustomWithClaims parses, validates, and returns a token with its respective claims. The keyFunc will receive the parsed token and should -// return the key for validating. If everything is kosher, err will be nil. -func ParseCustomWithClaims(rawToken string, claims MapClaims, keyFunc Keyfunc, algs ...jose.SignatureAlgorithm) (*Token, error) { - // Parse the token. - parsedToken, err := jwt.ParseSigned(rawToken, algs) - if err != nil { - return &Token{}, &ValidationError{Errors: ValidationErrorMalformed, text: err.Error()} +// CompactSignedString provides a compatible `jwt-go` Token.CompactSigned method +// +// > Get the complete, signed token +func (t *Token) CompactSignedString(k any) (tokenString string, err error) { + if _, ok := k.(unsafeNoneMagicConstant); ok { + return unsignedToken(t) } - // fill unverified claims - // This conversion is required because go-jose supports - // only marshalling structs or maps but not alias types from maps - // - // The KeyFunc(*Token) function requires the claims to be set into the - // Token, that is an unverified token, therefore an UnsafeClaimsWithoutVerification is done first - // then with the returned key, the claims gets verified. - if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, &ValidationError{Errors: ValidationErrorClaimsInvalid, text: err.Error()} + key := jose.SigningKey{ + Algorithm: t.SignatureAlgorithm, + Key: k, } - // creates an unsafe token - token, err := newToken(parsedToken, claims) - if err != nil { - return nil, err + opts := &jose.SignerOptions{ExtraHeaders: t.toSignedJoseHeader()} + + var signer jose.Signer + + if signer, err = jose.NewSigner(key, opts); err != nil { + return "", errorsx.WithStack(err) } - if keyFunc == nil { - return token, &ValidationError{Errors: ValidationErrorUnverifiable, text: "no Keyfunc was provided."} + // A explicit conversion from type alias MapClaims + // to map[string]any is required because the + // go-jose CompactSerialize() only support explicit maps + // as claims or structs but not type aliases from maps. + claims := map[string]any(t.Claims) + + if tokenString, err = jwt.Signed(signer).Claims(claims).Serialize(); err != nil { + return "", &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} } - // Call keyFunc callback to get verification key - verificationKey, err := keyFunc(token) - if err != nil { - // keyFunc returned an error - var ve *ValidationError + return tokenString, nil +} - if errors.As(err, &ve) { - return token, ve +func (t *Token) IsJWTProfileAccessToken() bool { + var ( + raw any + cty, typ string + ok bool + ) + + if t.HeaderJWE != nil && len(t.HeaderJWE) > 0 { + if raw, ok = t.HeaderJWE[consts.JSONWebTokenHeaderContentType]; ok { + cty, ok = raw.(string) + + if !ok { + return false + } + + if cty != consts.JSONWebTokenTypeAccessToken && cty != consts.JSONWebTokenTypeAccessTokenAlternative { + return false + } } + } - return token, &ValidationError{Errors: ValidationErrorUnverifiable, Inner: err} + if raw, ok = t.Header[consts.JSONWebTokenHeaderType]; !ok { + return false } - if verificationKey == nil { - return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: "keyfunc returned a nil verification key"} + + typ, ok = raw.(string) + + return ok && (typ == consts.JSONWebTokenTypeAccessToken || typ == consts.JSONWebTokenTypeAccessTokenAlternative) +} + +func unsignedToken(token *Token) (tokenString string, err error) { + token.Header[consts.JSONWebTokenHeaderAlgorithm] = consts.JSONWebTokenAlgNone + + if _, ok := token.Header[consts.JSONWebTokenHeaderType]; !ok { + token.Header[consts.JSONWebTokenHeaderType] = consts.JSONWebTokenTypeJWT } - // To verify signature go-jose requires a pointer to - // public key instead of the public key value. - // The pointer values provides that pointer. - // E.g. transform rsa.PublicKey -> *rsa.PublicKey - verificationKey = pointer(verificationKey) - // verify signature with returned key - _, validNoneKey := verificationKey.(*unsafeNoneMagicConstant) - isSignedToken := !(token.Method == SigningMethodNone && validNoneKey) - if isSignedToken { - if err := parsedToken.Claims(verificationKey, &claims); err != nil { - return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: err.Error()} - } + var ( + hbytes, bbytes []byte + ) + + if hbytes, err = json.Marshal(&token.Header); err != nil { + return "", errorsx.WithStack(err) } - // Validate claims - // This validation is performed to be backwards compatible - // with jwt-go library behavior - if err := claims.Valid(); err != nil { - if e, ok := err.(*ValidationError); !ok { - err = &ValidationError{Inner: e, Errors: ValidationErrorClaimsInvalid} - } - return token, err + if bbytes, err = json.Marshal(&token.Claims); err != nil { + return "", errorsx.WithStack(err) } - // set token as verified and validated - token.valid = true + return fmt.Sprintf("%s.%s.", base64.RawURLEncoding.EncodeToString(hbytes), base64.RawURLEncoding.EncodeToString(bbytes)), nil +} + +func newToken(parsedToken *jwt.JSONWebToken, claims MapClaims) (*Token, error) { + token := &Token{Claims: claims} + if len(parsedToken.Headers) != 1 { + return nil, &ValidationError{text: fmt.Sprintf("only one header supported, got %v", len(parsedToken.Headers)), Errors: ValidationErrorMalformed} + } + + // copy headers + h := parsedToken.Headers[0] + token.Header = map[string]any{ + consts.JSONWebTokenHeaderAlgorithm: h.Algorithm, + } + if h.KeyID != "" { + token.Header[consts.JSONWebTokenHeaderKeyIdentifier] = h.KeyID + token.KeyID = h.KeyID + } + + for k, v := range h.ExtraHeaders { + token.Header[string(k)] = v + } + + token.SignatureAlgorithm = jose.SignatureAlgorithm(h.Algorithm) + return token, nil } @@ -265,3 +424,25 @@ func pointer(v any) any { } return v } + +type PotentialTokenType int + +const ( + Unknown PotentialTokenType = iota + Opaque + SignedJWT + EncryptedJWT +) + +func GetPotentialTokenType(token string) PotentialTokenType { + switch strings.Count(token, ".") { + case 1: + return Opaque + case 2: + return SignedJWT + case 4: + return EncryptedJWT + default: + return Unknown + } +} diff --git a/token/jwt/token_test.go b/token/jwt/token_test.go index 81cb8fcc..3a0d0d4e 100644 --- a/token/jwt/token_test.go +++ b/token/jwt/token_test.go @@ -49,7 +49,7 @@ func TestUnsignedToken(t *testing.T) { "sub": "nestor", }) token.Header = tc.jwtHeaders - rawToken, err := token.SignedString(key) + rawToken, err := token.CompactSignedString(key) require.NoError(t, err) require.NotEmpty(t, rawToken) parts := strings.Split(rawToken, ".") @@ -72,12 +72,12 @@ func TestJWTHeaders(t *testing.T) { { name: "set JWT as 'typ' when the the type is not specified in the headers", jwtHeaders: map[string]any{}, - expectedType: JWTHeaderTypeValueJWT, + expectedType: consts.JSONWebTokenTypeJWT, }, { name: "'typ' set explicitly", - jwtHeaders: map[string]any{JWTHeaderKeyValueType: JWTHeaderTypeValueAccessTokenJWT}, - expectedType: JWTHeaderTypeValueAccessTokenJWT, + jwtHeaders: map[string]any{consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeAccessToken}, + expectedType: consts.JSONWebTokenTypeAccessToken, }, } for _, tc := range testCases { @@ -87,7 +87,7 @@ func TestJWTHeaders(t *testing.T) { require.NoError(t, err) require.Len(t, tk.Headers, 1) require.Equal(t, tk.Headers[0].Algorithm, "RS256") - require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[(JWTHeaderKeyValueType)]) + require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[(consts.JSONWebTokenHeaderType)]) }) } } @@ -453,7 +453,7 @@ func TestParser_Parse(t *testing.T) { func makeSampleToken(c MapClaims, m jose.SignatureAlgorithm, key any) string { token := NewWithClaims(m, c) - s, e := token.SignedString(key) + s, e := token.CompactSignedString(key) if e != nil { panic(e.Error()) @@ -465,7 +465,7 @@ func makeSampleToken(c MapClaims, m jose.SignatureAlgorithm, key any) string { func makeSampleTokenWithCustomHeaders(c MapClaims, m jose.SignatureAlgorithm, headers map[string]any, key any) string { token := NewWithClaims(m, c) token.Header = headers - s, e := token.SignedString(key) + s, e := token.CompactSignedString(key) if e != nil { panic(e.Error()) diff --git a/token/jwt/util.go b/token/jwt/util.go new file mode 100644 index 00000000..b1fd986c --- /dev/null +++ b/token/jwt/util.go @@ -0,0 +1,310 @@ +package jwt + +import ( + "context" + "crypto" + "fmt" + "regexp" + "strings" + + "github.com/go-jose/go-jose/v4" + "github.com/pkg/errors" + + "authelia.com/provider/oauth2/internal/consts" +) + +var ( + reSignedJWT = regexp.MustCompile(`^[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+$`) + reEncryptedJWT = regexp.MustCompile(`^[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+$`) +) + +// IsSignedJWT returns true if a given token string meets the basic criteria of a compact serialized signed JWT. +func IsSignedJWT(tokenString string) (signed bool) { + return reSignedJWT.MatchString(tokenString) +} + +// IsEncryptedJWT returns true if a given token string meets the basic criteria of a compact serialized encrypted JWT. +func IsEncryptedJWT(tokenString string) (encrypted bool) { + return reEncryptedJWT.MatchString(tokenString) +} + +// IsEncryptedJWTClientSecretAlg returns true if a given alg string is a client secret based algorithm i.e. symmetric. +func IsEncryptedJWTClientSecretAlg(alg string) (csa bool) { + switch a := jose.KeyAlgorithm(alg); a { + case jose.A128KW, jose.A192KW, jose.A256KW, jose.DIRECT, jose.A128GCMKW, jose.A192GCMKW, jose.A256GCMKW: + return true + default: + return IsEncryptedJWTPasswordBasedAlg(a) + } +} + +// IsEncryptedJWTPasswordBasedAlg returns true if a given jose.KeyAlgorithm is a Password Based Algorithm. +func IsEncryptedJWTPasswordBasedAlg(alg jose.KeyAlgorithm) (pba bool) { + switch alg { + case jose.PBES2_HS256_A128KW, jose.PBES2_HS384_A192KW, jose.PBES2_HS512_A256KW: + return true + default: + return false + } +} + +func headerValidateJWS(headers []jose.Header) (kid, alg string, err error) { + switch len(headers) { + case 1: + break + case 0: + return "", "", fmt.Errorf("jws header is missing") + default: + return "", "", fmt.Errorf("jws header is malformed") + } + + if headers[0].KeyID == "" { + return "", "", fmt.Errorf("jws header 'kid' value is missing or empty") + } + + if headers[0].Algorithm == "" { + return "", "", fmt.Errorf("jws header 'alg' value is missing or empty") + } + + if headers[0].JSONWebKey != nil { + return "", "", fmt.Errorf("jws header 'jwk' value is present but not supported") + } + + return headers[0].KeyID, headers[0].Algorithm, nil +} + +func headerValidateJWSNested(headers []jose.Header, cty string) (err error) { + switch len(headers) { + case 1: + break + case 0: + return fmt.Errorf("jws header is missing") + default: + return fmt.Errorf("jws header is malformed") + } + + typ, ok := headers[0].ExtraHeaders[consts.JSONWebTokenHeaderType] + if !ok { + return fmt.Errorf("jws header 'typ' value is missing") + } + + switch typ { + case "": + return fmt.Errorf("jws header 'typ' value is empty") + case cty: + return nil + default: + return fmt.Errorf("jws header 'typ' value '%s' is invalid: jwe header 'cty' value '%s' should match the jws header 'typ' value", typ, cty) + } +} + +func headerValidateJWE(header jose.Header) (kid, alg, enc, cty string, err error) { + if header.KeyID == "" && !IsEncryptedJWTClientSecretAlg(header.Algorithm) { + return "", "", "", "", fmt.Errorf("jwe header 'kid' value is missing or empty") + } + + if header.Algorithm == "" { + return "", "", "", "", fmt.Errorf("jwe header 'alg' value is missing or empty") + } + + var ( + value any + ok bool + ) + + if IsEncryptedJWTPasswordBasedAlg(jose.KeyAlgorithm(header.Algorithm)) { + if value, ok = header.ExtraHeaders[consts.JSONWebTokenHeaderPBES2Count]; ok { + switch p2c := value.(type) { + case float64: + if p2c > 5000000 { + return "", "", "", "", fmt.Errorf("jwe header 'p2c' has an invalid value '%d': more than 5,000,000", int(p2c)) + } else if p2c < 200000 { + return "", "", "", "", fmt.Errorf("jwe header 'p2c' has an invalid value '%d': less than 200,000", int(p2c)) + } + + default: + return "", "", "", "", fmt.Errorf("jwe header 'p2c' value has invalid type %T", p2c) + } + } + } + + if value, ok = header.ExtraHeaders[consts.JSONWebTokenHeaderEncryptionAlgorithm]; ok { + switch encv := value.(type) { + case string: + if encv != "" { + enc = encv + + break + } + + return "", "", "", "", fmt.Errorf("jwe header 'enc' value is empty") + default: + return "", "", "", "", fmt.Errorf("jwe header 'enc' value has invalid type %T", encv) + } + } + + if value, ok = header.ExtraHeaders[consts.JSONWebTokenHeaderContentType]; !ok { + return "", "", "", "", fmt.Errorf("jwe header 'cty' value is missing") + } else { + switch ctyv := value.(type) { + case string: + switch ctyv { + case consts.JSONWebTokenTypeJWT, consts.JSONWebTokenTypeAccessToken, consts.JSONWebTokenTypeAccessTokenAlternative, consts.JSONWebTokenTypeTokenIntrospection: + cty = ctyv + break + default: + return "", "", "", "", fmt.Errorf("jwe header 'cty' value '%s' is invalid", cty) + } + default: + return "", "", "", "", fmt.Errorf("jwe header 'cty' value has invalid type %T", cty) + } + } + + if header.JSONWebKey != nil { + return "", "", "", "", fmt.Errorf("jwe header 'jwk' value is present but not supported") + } + + return header.KeyID, header.Algorithm, enc, cty, nil +} + +// PrivateKey properly describes crypto.PrivateKey. +type PrivateKey interface { + Public() crypto.PublicKey + Equal(x crypto.PrivateKey) bool +} + +type JWKLookupError struct { + Description string +} + +func (e *JWKLookupError) GetDescription() string { + return e.Description +} + +func (e *JWKLookupError) Error() string { + return fmt.Sprintf("Error occurrered looking up JSON Web Key: %s", e.Description) +} + +// FindClientPublicJWK given a BaseClient, JWKSFetcherStrategy, and search parameters will return a *jose.JSONWebKey on +// a valid match. The *jose.JSONWebKey is guaranteed to match the alg and use values, and if strict is true it must +// match the kid value as well. +func FindClientPublicJWK(ctx context.Context, client BaseClient, fetcher JWKSFetcherStrategy, kid, alg, use string, strict bool) (key *jose.JSONWebKey, err error) { + if strict && kid == "" { + return nil, &JWKLookupError{Description: "The JSON Web Key strict search was attempted without a kid but the strict search doesn't permit this."} + } + + var ( + keys *jose.JSONWebKeySet + ) + + if keys = client.GetJSONWebKeys(); keys != nil { + return SearchJWKS(keys, kid, alg, use, strict) + } + + if location := client.GetJSONWebKeysURI(); len(location) > 0 { + if keys, err = fetcher.Resolve(ctx, location, false); err != nil { + return nil, err + } + + if key, err = SearchJWKS(keys, kid, alg, use, strict); err == nil { + return key, nil + } + + if keys, err = fetcher.Resolve(ctx, location, true); err != nil { + return nil, err + } + + return SearchJWKS(keys, kid, alg, use, strict) + } + + return nil, &JWKLookupError{Description: "No JWKs have been registered for the client."} +} + +func SearchJWKS(jwks *jose.JSONWebKeySet, kid, alg, use string, strict bool) (key *jose.JSONWebKey, err error) { + if len(jwks.Keys) == 0 { + return nil, &JWKLookupError{Description: "The retrieved JSON Web Key Set does not contain any key."} + } + + if strict && kid == "" { + return nil, &JWKLookupError{Description: "The JSON Web Key strict search was attempted without a kid but the strict search doesn't permit this."} + } + + var keys []jose.JSONWebKey + + if kid == "" { + keys = jwks.Keys + } else { + keys = jwks.Key(kid) + } + + if len(keys) == 0 { + return nil, &JWKLookupError{Description: fmt.Sprintf("The JSON Web Token uses signing key with kid '%s' which was not found.", kid)} + } + + for _, k := range keys { + if k.Use != use { + continue + } + + if k.Algorithm != alg { + continue + } + + return &k, nil + } + + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to find JSON web key with kid '%s', use '%s', and alg '%s' in JSON Web Key Set.", kid, use, alg)} +} + +func NewJWKFromClientSecret(ctx context.Context, client BaseClient, kid, alg, use string) (jwk *jose.JSONWebKey, err error) { + var secret []byte + + if secret, err = client.GetClientSecretPlainText(); err != nil { + return nil, err + } + + return &jose.JSONWebKey{ + Key: secret, + KeyID: kid, + Algorithm: alg, + Use: use, + }, nil +} + +func encodeCompactSigned(ctx context.Context, claims MapClaims, headers Mapper, key *jose.JSONWebKey) (tokenString string, signature string, err error) { + token := New() + + token.SetJWS(headers, claims, jose.SignatureAlgorithm(key.Algorithm)) + + return token.CompactSigned(key) +} + +func encodeNestedCompactEncrypted(ctx context.Context, claims MapClaims, headers, headersJWE Mapper, keySig, keyEnc *jose.JSONWebKey, enc jose.ContentEncryption) (tokenString string, signature string, err error) { + token := New() + + token.SetJWS(headers, claims, jose.SignatureAlgorithm(keySig.Algorithm)) + token.SetJWE(headersJWE, jose.KeyAlgorithm(keyEnc.Algorithm), enc, jose.NONE) + + return token.CompactEncrypted(keySig, keyEnc) +} + +func getJWTSignature(tokenString string) (signature string, err error) { + switch segments := strings.SplitN(tokenString, ".", 5); len(segments) { + case 5: + return "", errors.WithStack(errors.New("invalid token: the token is probably encrypted")) + case 3: + return segments[2], nil + default: + return "", errors.WithStack(fmt.Errorf("invalid token: the format is unknown")) + } +} + +func assign(a, b map[string]any) map[string]any { + for k, w := range b { + if _, ok := a[k]; ok { + continue + } + a[k] = w + } + return a +} From edd61871b1dd8d0d70da46a54e902f1041baa31c Mon Sep 17 00:00:00 2001 From: James Elliott Date: Wed, 4 Sep 2024 16:22:11 +1000 Subject: [PATCH 02/33] refactor: misc --- authorize_request_handler.go | 27 ++++++++++--- ...orize_request_handler_oidc_request_test.go | 4 +- token/jwt/jwt_strategy.go | 40 +++++++++++-------- token/jwt/util.go | 8 ++-- 4 files changed, 51 insertions(+), 28 deletions(-) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 46573d2a..cca3f587 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -19,12 +19,29 @@ import ( "authelia.com/provider/oauth2/x/errorsx" ) -func wrapSigningKeyFailure(outer *RFC6749Error, inner error) *RFC6749Error { - outer = outer.WithWrap(inner).WithDebugError(inner) +func wrapSigningKeyFailure(outer *RFC6749Error, inner *jwt.ValidationError) *RFC6749Error { + outer = outer.WithWrap(inner) if e := new(RFC6749Error); errors.As(inner, &e) { - return outer.WithHintf("%s %s", outer.Reason(), e.Reason()) + return outer.WithHintf("%s %s", outer.Reason(), e.Reason()).WithDebugError(inner) + } + + switch { + case inner.Has(jwt.ValidationErrorMalformed): + return outer.WithDebugf("The object is malformed. The following error occurred trying to validate the object: %s.", strings.TrimPrefix(inner.Error(), "go-jose/go-jose: ")) + case inner.Has(jwt.ValidationErrorUnverifiable | jwt.ValidationErrorSignatureInvalid): + return outer.WithDebugf("The object signature could not be verified. The following error occurred trying to verify the signature: %s.", strings.TrimPrefix(inner.Error(), "go-jose/go-jose: ")) + case inner.Has(jwt.ValidationErrorExpired): + return outer.WithDebugf("The object could not be verified. The object is expired.") + case inner.Has(jwt.ValidationErrorAudience | + jwt.ValidationErrorIssuedAt | + jwt.ValidationErrorIssuer | + jwt.ValidationErrorNotValidYet | + jwt.ValidationErrorId | + jwt.ValidationErrorClaimsInvalid): + return outer.WithDebugf("The object could not be verified. One or more claims were not expected. The following error occurred trying to validate the claims: %s.", strings.TrimPrefix(inner.Error(), "go-jose/go-jose: ")) + default: + return outer.WithDebugf("The object could not be verified. The following error occurred trying to validate the object: %s.", strings.TrimPrefix(inner.Error(), "go-jose/go-jose: ")) } - return outer } // TODO: Refactor time permitting. @@ -129,7 +146,7 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co if err != nil { var e *jwt.ValidationError if errors.As(err, &e) { - return wrapSigningKeyFailure(ErrInvalidRequestObject.WithHintf("The OAuth 2.0 client with id '%s' could not validate the request object.", client.GetID()), err) + return wrapSigningKeyFailure(ErrInvalidRequestObject.WithHintf("The OAuth 2.0 client with id '%s' could not validate the request object.", client.GetID()), e) } else { return errorsx.WithStack(ErrInvalidRequestObject.WithHintf("The OAuth 2.0 client with id '%s' could not validate the request object.", client.GetID()).WithDebugError(err)) } diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index ac618198..974d8495 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -187,7 +187,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'foo' provided a request object which failed to validate with error: go-jose/go-jose: compact JWS format must have three parts.", + errString: "The request parameter contains an invalid Request Object. The OAuth 2.0 client with id 'foo' could not validate the request object. The object is malformed. The following error occurred trying to validate the object: compact JWS format must have three parts.", }, { name: "ShouldFailUnknownKID", @@ -308,7 +308,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldPassRequestAlgNone", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidNone}}, - client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone}, + client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, }, { diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index b69d9692..e1d02a72 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -157,7 +157,7 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . claims := MapClaims{} if err = t.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err}) } var kid, alg string @@ -166,28 +166,34 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } - if o.jwsKeyFunc != nil { - if key, err = o.jwsKeyFunc(ctx, t, claims); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - } else if o.client != nil && o.client.IsClientSigned() { - if ckid := o.client.GetSignatureKeyID(); ckid != "" && ckid != kid { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: fmt.Errorf("error validating the jws header: kid '%s' does not match the registered kid '%s'", kid, ckid)}) + if alg == consts.JSONWebTokenAlgNone { + if err = t.UnsafeClaimsWithoutVerification(&claims); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) } + } else { + if o.jwsKeyFunc != nil { + if key, err = o.jwsKeyFunc(ctx, t, claims); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if o.client != nil && o.client.IsClientSigned() { + if ckid := o.client.GetSignatureKeyID(); ckid != "" && ckid != kid { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: fmt.Errorf("error validating the jws header: kid '%s' does not match the registered kid '%s'", kid, ckid)}) + } - if calg := o.client.GetSignatureAlg(); calg != "" && calg != alg { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not match the registered alg '%s'", alg, calg)}) - } + if calg := o.client.GetSignatureAlg(); calg != "" && calg != alg { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not match the registered alg '%s'", alg, calg)}) + } - if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseSignature, true); err != nil { + if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseSignature, true); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseSignature); err != nil { return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } - } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseSignature); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - if err = t.Claims(key.Public(), &claims); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) + if err = t.Claims(key.Public(), &claims); err != nil { + return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) + } } if token, err = newToken(t, claims); err != nil { diff --git a/token/jwt/util.go b/token/jwt/util.go index b1fd986c..fc5e29fe 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -58,14 +58,14 @@ func headerValidateJWS(headers []jose.Header) (kid, alg string, err error) { return "", "", fmt.Errorf("jws header is malformed") } - if headers[0].KeyID == "" { - return "", "", fmt.Errorf("jws header 'kid' value is missing or empty") - } - if headers[0].Algorithm == "" { return "", "", fmt.Errorf("jws header 'alg' value is missing or empty") } + if headers[0].KeyID == "" && headers[0].Algorithm != consts.JSONWebTokenAlgNone { + return "", "", fmt.Errorf("jws header 'kid' value is missing or empty") + } + if headers[0].JSONWebKey != nil { return "", "", fmt.Errorf("jws header 'jwk' value is present but not supported") } From 1fd7ad629e998a94c9f6ad340ce4aac5d5bd50d8 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Wed, 4 Sep 2024 22:35:36 +1000 Subject: [PATCH 03/33] temp --- authorize_request_handler.go | 103 +++++++++++------- ...orize_request_handler_oidc_request_test.go | 34 +++--- compose/compose.go | 3 +- token/jwt/claims_map.go | 34 ++++-- token/jwt/jwt_strategy.go | 84 ++++++++------ token/jwt/token.go | 32 ++---- token/jwt/util.go | 49 ++++++--- token/jwt/validation_error.go | 16 +-- 8 files changed, 205 insertions(+), 150 deletions(-) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index cca3f587..7aea7bc2 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -19,31 +19,6 @@ import ( "authelia.com/provider/oauth2/x/errorsx" ) -func wrapSigningKeyFailure(outer *RFC6749Error, inner *jwt.ValidationError) *RFC6749Error { - outer = outer.WithWrap(inner) - if e := new(RFC6749Error); errors.As(inner, &e) { - return outer.WithHintf("%s %s", outer.Reason(), e.Reason()).WithDebugError(inner) - } - - switch { - case inner.Has(jwt.ValidationErrorMalformed): - return outer.WithDebugf("The object is malformed. The following error occurred trying to validate the object: %s.", strings.TrimPrefix(inner.Error(), "go-jose/go-jose: ")) - case inner.Has(jwt.ValidationErrorUnverifiable | jwt.ValidationErrorSignatureInvalid): - return outer.WithDebugf("The object signature could not be verified. The following error occurred trying to verify the signature: %s.", strings.TrimPrefix(inner.Error(), "go-jose/go-jose: ")) - case inner.Has(jwt.ValidationErrorExpired): - return outer.WithDebugf("The object could not be verified. The object is expired.") - case inner.Has(jwt.ValidationErrorAudience | - jwt.ValidationErrorIssuedAt | - jwt.ValidationErrorIssuer | - jwt.ValidationErrorNotValidYet | - jwt.ValidationErrorId | - jwt.ValidationErrorClaimsInvalid): - return outer.WithDebugf("The object could not be verified. One or more claims were not expected. The following error occurred trying to validate the claims: %s.", strings.TrimPrefix(inner.Error(), "go-jose/go-jose: ")) - default: - return outer.WithDebugf("The object could not be verified. The following error occurred trying to validate the object: %s.", strings.TrimPrefix(inner.Error(), "go-jose/go-jose: ")) - } -} - // TODO: Refactor time permitting. // //nolint:gocyclo @@ -144,30 +119,28 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co token, err := strategy.Decode(ctx, assertion, jwt.WithSigAlgorithm(jwt.SignatureAlgorithmsNone...), jwt.WithJARClient(client)) if err != nil { - var e *jwt.ValidationError - if errors.As(err, &e) { - return wrapSigningKeyFailure(ErrInvalidRequestObject.WithHintf("The OAuth 2.0 client with id '%s' could not validate the request object.", client.GetID()), e) - } else { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf("The OAuth 2.0 client with id '%s' could not validate the request object.", client.GetID()).WithDebugError(err)) - } + return errorsx.WithStack(wrapRequestObjectDecodeError(token, client, openid, err)) } if algAny { if token.SignatureAlgorithm == "none" { - - } - - if kid := client.GetRequestObjectSigningKeyID(); kid != "" && kid != token.KeyID { - + return errorsx.WithStack( + ErrInvalidRequestObject. + WithHintf("%s client provided a request object that has an invalid 'kid' or 'alg' header value.", hintRequestObjectPrefix(openid)). + WithDebugf("%s client with id '%s' was not explicitly registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'none' in the header.", hintRequestObjectPrefix(openid), client.GetID())) } } else if string(token.SignatureAlgorithm) != alg { - + return errorsx.WithStack( + ErrInvalidRequestObject. + WithHintf("%s client provided a request object that has an invalid 'kid' or 'alg' header value.", hintRequestObjectPrefix(openid)). + WithDebugf("%s client with id '%s' was registered with a 'request_object_signing_alg' value of '%s' but the request object had the 'alg' value '%s' in the header.", hintRequestObjectPrefix(openid), client.GetID(), alg, token.SignatureAlgorithm)) } - if token.SignatureAlgorithm == "none" && !algNone { - print("") - } else if !algAny && string(token.SignatureAlgorithm) != client.GetRequestObjectSigningAlg() { - print("") + if kid := client.GetRequestObjectSigningKeyID(); kid != "" && kid != token.KeyID { + return errorsx.WithStack( + ErrInvalidRequestObject. + WithHintf("%s client provided a request object that has an invalid 'kid' or 'alg' header value.", hintRequestObjectPrefix(openid)). + WithDebugf("%s client with id '%s' was registered with a 'request_object_signing_key_id' value of '%s' but the request object had the 'kid' value '%s' in the header.", hintRequestObjectPrefix(openid), client.GetID(), kid, token.KeyID)) } claims := token.Claims @@ -581,3 +554,51 @@ func (f *Fosite) newAuthorizeRequest(ctx context.Context, r *http.Request, isPAR return request, nil } + +func wrapRequestObjectDecodeError(token *jwt.Token, client JARClient, openid bool, inner error) (outer *RFC6749Error) { + outer = ErrInvalidRequestObject.WithWrap(inner).WithHintf("%s request object could not be decoded or validated.", hintRequestObjectPrefix(openid)) + + if errJWTValidation := new(jwt.ValidationError); errors.As(inner, &errJWTValidation) { + switch { + case errJWTValidation.Has(jwt.ValidationErrorKeyIDInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'kid' value '%s' but the request object was signed with the 'kid' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningKeyID(), token.KeyID) + case errJWTValidation.Has(jwt.ValidationErrorAlgorithmInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'alg' value '%s' but the request object was signed with the 'alg' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningAlg(), token.SignatureAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorMalformed): + return outer.WithDebugf("%s client with id '%s' provided a request object that was malformed. The following error occurred trying to validate the request object: %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): + return outer.WithDebugf("%s client with id '%s' provided a request object that was not able to be verified. The following error occurred trying to validate the object: %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorSignatureInvalid): + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid signature. The following error occurred trying to validate the request object signature: %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorExpired): + exp, ok := token.Claims.GetExpiresAt() + if ok { + return outer.WithDebugf("%s client with id '%s' provided a request object that was expired. The request object expired at %d.", hintRequestObjectPrefix(openid), client.GetID(), exp) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that was expired. Error occurred trying to validate the 'exp' claim': %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuedAt): + iat, ok := token.Claims.GetIssuedAt() + if ok { + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object was issued at %d.", hintRequestObjectPrefix(openid), client.GetID(), iat) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. Error occurred trying to validate the 'iat' claim: %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + } + case errJWTValidation.Has(jwt.ValidationErrorNotValidYet): + nbf, ok := token.Claims.GetNotBefore() + if ok { + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object is not valid before %d.", hintRequestObjectPrefix(openid), client.GetID(), nbf) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. Error occurred trying to validate the 'nbf' claim: %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + } + case errJWTValidation.Has(jwt.ValidationErrorClaimsInvalid): + return outer.WithDebugf("%s client with id '%s' provided a request object that had one or more invalid claims. Error occurred trying to validate the request objects claims: %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + default: + return outer.WithDebugf("%s client with id '%s' provided a request object that could not be validated. Error occurred trying to validate the request object: %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + } + } else if errJWKLookup := new(jwt.JWKLookupError); errors.As(inner, &errJWKLookup) { + return outer.WithDebugf("%s client with id '%s' provided a request object that could not be validated due to a key lookup error. %s", hintRequestObjectPrefix(openid), client.GetID(), errJWKLookup.Description) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that could not be validated. %s", hintRequestObjectPrefix(openid), client.GetID(), ErrorToDebugRFC6749Error(inner).Error()) + } +} diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index 974d8495..af3e9cd3 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -187,7 +187,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. The OAuth 2.0 client with id 'foo' could not validate the request object. The object is malformed. The following error occurred trying to validate the object: compact JWS format must have three parts.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' provided a request object that was malformed. The following error occurred trying to validate the request object: compact JWS format must have three parts.", }, { name: "ShouldFailUnknownKID", @@ -195,7 +195,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test"}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. Unable to retrieve RSA signing key from OAuth 2.0 Client. The JSON Web Token uses signing key with kid 'does-not-exists', which could not be found. The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed. The JSON Web Token uses signing key with kid 'does-not-exists', which could not be found.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' provided a request object that was not able to be verified. The following error occurred trying to validate the object: Error occurred looking up JSON Web Key: The JSON Web Token uses signing key with kid 'does-not-exists' which was not found..", }, { name: "ShouldFailBadAlgRS256", @@ -203,7 +203,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test"}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'test' expects request objects to be signed with the 'RS256' algorithm but the request object was signed with the 'HS256' algorithm.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'alg' value 'RS256' but the request object was signed with the 'alg' value 'HS256'.", }, { name: "ShouldFailMismatchedClientID", @@ -279,7 +279,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'foo' expects request objects to be signed with the 'RS256' algorithm but the request object was signed with the 'none' algorithm.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was registered with a 'request_object_signing_alg' value of 'RS256' but the request object had the 'alg' value 'none' in the header.", }, { name: "ShouldFailRequestURIAlgNone", @@ -287,7 +287,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", RequestURIs: []string{root.JoinPath("request-object", "valid", "none.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'foo' expects request objects to be signed with the 'RS256' algorithm but the request object was signed with the 'none' algorithm.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was registered with a 'request_object_signing_alg' value of 'RS256' but the request object had the 'alg' value 'none' in the header.", }, { name: "ShouldFailRequestRS256", @@ -295,7 +295,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'foo' expects request objects to be signed with the 'none' algorithm but the request object was signed with the 'RS256' algorithm.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'none' but the request object was signed with the 'alg' value 'RS256'.", }, { name: "ShouldFailRequestURIRS256", @@ -303,7 +303,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, RequestURIs: []string{root.JoinPath("request-object", "valid", "standard.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "standard.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'foo' expects request objects to be signed with the 'none' algorithm but the request object was signed with the 'RS256' algorithm.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'none' but the request object was signed with the 'alg' value 'RS256'.", }, { name: "ShouldPassRequestAlgNone", @@ -318,16 +318,20 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, }, { - name: "ShouldPassRequestAlgNoneAllowAny", - have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidNone}}, - client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String()}, - expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, + name: "ShouldPassRequestAlgNoneAllowAny", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidNone}}, + client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), DefaultClient: &DefaultClient{ID: "foo"}}, + expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was not explicitly registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'none' in the header.", }, { - name: "ShouldPassRequestURIAlgNoneAllowAny", - have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}}, - client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "", RequestURIs: []string{root.JoinPath("request-object", "valid", "none.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, - expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, + name: "ShouldPassRequestURIAlgNoneAllowAny", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}}, + client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "", RequestURIs: []string{root.JoinPath("request-object", "valid", "none.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, + expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was not explicitly registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'none' in the header.", }, { name: "ShouldFailRequestBadAudience", diff --git a/compose/compose.go b/compose/compose.go index 7596445b..dde1380c 100644 --- a/compose/compose.go +++ b/compose/compose.go @@ -4,9 +4,10 @@ package compose import ( + "context" + "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/token/jwt" - "context" ) type Factory func(config oauth2.Configurator, storage any, strategy any) any diff --git a/token/jwt/claims_map.go b/token/jwt/claims_map.go index 85ed22de..b5b6b063 100644 --- a/token/jwt/claims_map.go +++ b/token/jwt/claims_map.go @@ -25,6 +25,14 @@ var TimeFunc = time.Now // This is the default claims type if you don't supply one type MapClaims map[string]any +// VerifyIssuer compares the iss claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { + iss, _ := m[consts.ClaimIssuer].(string) + + return verifyIss(iss, cmp, req) +} + // VerifyAudience compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyAudience(cmp string, req bool) bool { @@ -40,35 +48,45 @@ func (m MapClaims) VerifyAudience(cmp string, req bool) bool { return false } +// GetExpiresAt returns the exp claim. +func (m MapClaims) GetExpiresAt() (exp int64, ok bool) { + return m.toInt64(consts.ClaimExpirationTime) +} + // VerifyExpiresAt compares the exp claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { - if v, ok := m.toInt64(consts.ClaimExpirationTime); ok { + if v, ok := m.GetExpiresAt(); ok { return verifyExp(v, cmp, req) } + return !req } +// GetIssuedAt returns the iat claim. +func (m MapClaims) GetIssuedAt() (exp int64, ok bool) { + return m.toInt64(consts.ClaimIssuedAt) +} + // VerifyIssuedAt compares the iat claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { - if v, ok := m.toInt64(consts.ClaimIssuedAt); ok { + if v, ok := m.GetIssuedAt(); ok { return verifyIat(v, cmp, req) } + return !req } -// VerifyIssuer compares the iss claim against cmp. -// If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { - iss, _ := m[consts.ClaimIssuer].(string) - return verifyIss(iss, cmp, req) +// GetNotBefore returns the nbf claim. +func (m MapClaims) GetNotBefore() (exp int64, ok bool) { + return m.toInt64(consts.ClaimNotBefore) } // VerifyNotBefore compares the nbf claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { - if v, ok := m.toInt64(consts.ClaimNotBefore); ok { + if v, ok := m.GetNotBefore(); ok { return verifyNbf(v, cmp, req) } diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index e1d02a72..76174b0b 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -98,7 +98,7 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . for _, opt := range opts { if err = opt(o); err != nil { - return nil, errorsx.WithStack(err) + return token, errorsx.WithStack(err) } } @@ -110,7 +110,7 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . if IsEncryptedJWT(tokenString) { if jwe, err = jose.ParseEncryptedCompact(tokenString, o.keyAlgorithm, o.contentEncryption); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } var ( @@ -118,90 +118,104 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . ) if kid, alg, _, cty, err = headerValidateJWE(jwe.Header); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } if o.jweKeyFunc != nil { if key, err = o.jweKeyFunc(ctx, jwe, kid, alg); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } } else if IsEncryptedJWTClientSecretAlg(alg) { if o.client == nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } if key, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } var rawJWT []byte if rawJWT, err = jwe.Decrypt(key); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } if t, err = jwt.ParseSigned(string(rawJWT), o.sigAlgorithm); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } if err = headerValidateJWSNested(t.Headers, cty); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } } else if t, err = jwt.ParseSigned(tokenString, o.sigAlgorithm); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } + if token, err = newToken(t, nil); err != nil { + return token, errorsx.WithStack(err) + } + + token.AssignJWE(jwe) + claims := MapClaims{} if err = t.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err}) } + token.Claims = claims + var kid, alg string if kid, alg, err = headerValidateJWS(t.Headers); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } - if alg == consts.JSONWebTokenAlgNone { - if err = t.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) - } - } else { + if alg != consts.JSONWebTokenAlgNone { if o.jwsKeyFunc != nil { if key, err = o.jwsKeyFunc(ctx, t, claims); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } } else if o.client != nil && o.client.IsClientSigned() { - if ckid := o.client.GetSignatureKeyID(); ckid != "" && ckid != kid { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: fmt.Errorf("error validating the jws header: kid '%s' does not match the registered kid '%s'", kid, ckid)}) - } + var ( + ckid, calg string + ) + + ckid = o.client.GetSignatureKeyID() - if calg := o.client.GetSignatureAlg(); calg != "" && calg != alg { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not match the registered alg '%s'", alg, calg)}) + if calg = o.client.GetSignatureAlg(); calg != "" && calg != alg { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorAlgorithmInvalid, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not match the registered alg '%s'", alg, calg)}) } - if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseSignature, true); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + if IsSignedJWTClientSecretAlg(alg) { + if ckid != "" { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorKeyIDInvalid, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not support tokens with a kid but the token has kid '%s'", alg, kid)}) + } + + if key, err = NewJWKFromClientSecret(ctx, o.client, "", alg, consts.JSONWebTokenUseSignature); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else { + if ckid != "" && ckid != kid { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorKeyIDInvalid, Inner: fmt.Errorf("error validating the jws header: kid '%s' does not match the registered kid '%s'", kid, ckid)}) + } + + if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseSignature, true); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } } } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseSignature); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } if err = t.Claims(key.Public(), &claims); err != nil { - return nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) } } - if token, err = newToken(t, claims); err != nil { - return nil, errorsx.WithStack(err) - } - - token.AssignJWE(jwe) - if err = claims.Valid(); err != nil { return token, errorsx.WithStack(err) } @@ -286,5 +300,5 @@ func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, op return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } - return string(tokenRaw), signature, jwe, nil + return tokenString, signature, jwe, nil } diff --git a/token/jwt/token.go b/token/jwt/token.go index d76cbf57..7616e10b 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "reflect" - "strings" "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" @@ -18,6 +17,7 @@ import ( "authelia.com/provider/oauth2/x/errorsx" ) +// New returns a new Token. func New() *Token { return &Token{ Header: map[string]any{}, @@ -243,6 +243,8 @@ func (t *Token) AssignJWE(jwe *jose.JSONWebEncryption) { t.KeyAlgorithm = jose.KeyAlgorithm(jwe.Header.Algorithm) } +// CompactEncrypted serializes this token as a Compact Encrypted string, and returns the token string, signature, and +// an error if one occurred. func (t *Token) CompactEncrypted(keySig, keyEnc any) (tokenString, signature string, err error) { var ( signed string @@ -291,6 +293,8 @@ func (t *Token) CompactEncrypted(keySig, keyEnc any) (tokenString, signature str return tokenString, signature, nil } +// CompactSigned serializes this token as a Compact Signed string, and returns the token string, signature, and +// an error if one occurred. func (t *Token) CompactSigned(k any) (tokenString, signature string, err error) { if tokenString, err = t.CompactSignedString(k); err != nil { return "", "", err @@ -337,11 +341,11 @@ func (t *Token) CompactSignedString(k any) (tokenString string, err error) { return tokenString, nil } -func (t *Token) IsJWTProfileAccessToken() bool { +// IsJWTProfileAccessToken returns true if the token is a JWT Profile Access Token. +func (t *Token) IsJWTProfileAccessToken() (ok bool) { var ( raw any cty, typ string - ok bool ) if t.HeaderJWE != nil && len(t.HeaderJWE) > 0 { @@ -424,25 +428,3 @@ func pointer(v any) any { } return v } - -type PotentialTokenType int - -const ( - Unknown PotentialTokenType = iota - Opaque - SignedJWT - EncryptedJWT -) - -func GetPotentialTokenType(token string) PotentialTokenType { - switch strings.Count(token, ".") { - case 1: - return Opaque - case 2: - return SignedJWT - case 4: - return EncryptedJWT - default: - return Unknown - } -} diff --git a/token/jwt/util.go b/token/jwt/util.go index fc5e29fe..c8030d8c 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -28,7 +28,18 @@ func IsEncryptedJWT(tokenString string) (encrypted bool) { return reEncryptedJWT.MatchString(tokenString) } -// IsEncryptedJWTClientSecretAlg returns true if a given alg string is a client secret based algorithm i.e. symmetric. +// IsSignedJWTClientSecretAlg returns true if the given alg string is a client secret based signature algorithm. +func IsSignedJWTClientSecretAlg(alg string) (csa bool) { + switch a := jose.SignatureAlgorithm(alg); a { + case jose.HS256, jose.HS384, jose.HS512: + return true + default: + return false + } +} + +// IsEncryptedJWTClientSecretAlg returns true if a given alg string is a client secret based encryption algorithm +// i.e. symmetric. func IsEncryptedJWTClientSecretAlg(alg string) (csa bool) { switch a := jose.KeyAlgorithm(alg); a { case jose.A128KW, jose.A192KW, jose.A256KW, jose.DIRECT, jose.A128GCMKW, jose.A192GCMKW, jose.A256GCMKW: @@ -58,12 +69,8 @@ func headerValidateJWS(headers []jose.Header) (kid, alg string, err error) { return "", "", fmt.Errorf("jws header is malformed") } - if headers[0].Algorithm == "" { - return "", "", fmt.Errorf("jws header 'alg' value is missing or empty") - } - - if headers[0].KeyID == "" && headers[0].Algorithm != consts.JSONWebTokenAlgNone { - return "", "", fmt.Errorf("jws header 'kid' value is missing or empty") + if headers[0].Algorithm == "" && headers[0].KeyID == "" { + return "", "", fmt.Errorf("jws header 'alg' and 'kid' values are missing or empty") } if headers[0].JSONWebKey != nil { @@ -182,17 +189,13 @@ func (e *JWKLookupError) GetDescription() string { } func (e *JWKLookupError) Error() string { - return fmt.Sprintf("Error occurrered looking up JSON Web Key: %s", e.Description) + return fmt.Sprintf("Error occurred looking up JSON Web Key: %s", e.Description) } // FindClientPublicJWK given a BaseClient, JWKSFetcherStrategy, and search parameters will return a *jose.JSONWebKey on // a valid match. The *jose.JSONWebKey is guaranteed to match the alg and use values, and if strict is true it must // match the kid value as well. func FindClientPublicJWK(ctx context.Context, client BaseClient, fetcher JWKSFetcherStrategy, kid, alg, use string, strict bool) (key *jose.JSONWebKey, err error) { - if strict && kid == "" { - return nil, &JWKLookupError{Description: "The JSON Web Key strict search was attempted without a kid but the strict search doesn't permit this."} - } - var ( keys *jose.JSONWebKeySet ) @@ -225,10 +228,6 @@ func SearchJWKS(jwks *jose.JSONWebKeySet, kid, alg, use string, strict bool) (ke return nil, &JWKLookupError{Description: "The retrieved JSON Web Key Set does not contain any key."} } - if strict && kid == "" { - return nil, &JWKLookupError{Description: "The JSON Web Key strict search was attempted without a kid but the strict search doesn't permit this."} - } - var keys []jose.JSONWebKey if kid == "" { @@ -241,6 +240,8 @@ func SearchJWKS(jwks *jose.JSONWebKeySet, kid, alg, use string, strict bool) (ke return nil, &JWKLookupError{Description: fmt.Sprintf("The JSON Web Token uses signing key with kid '%s' which was not found.", kid)} } + var matched []jose.JSONWebKey + for _, k := range keys { if k.Use != use { continue @@ -250,12 +251,24 @@ func SearchJWKS(jwks *jose.JSONWebKeySet, kid, alg, use string, strict bool) (ke continue } - return &k, nil + matched = append(matched, k) } - return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to find JSON web key with kid '%s', use '%s', and alg '%s' in JSON Web Key Set.", kid, use, alg)} + switch len(matched) { + case 1: + return &matched[0], nil + case 0: + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to find JSON web key with kid '%s', use '%s', and alg '%s' in JSON Web Key Set.", kid, use, alg)} + default: + if strict { + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to find JSON web key with kid '%s', use '%s', and alg '%s' in JSON Web Key Set.", kid, use, alg)} + } + + return &matched[0], nil + } } +// NewJWKFromClientSecret returns a JWK from a client secret. func NewJWKFromClientSecret(ctx context.Context, client BaseClient, kid, alg, use string) (jwk *jose.JSONWebKey, err error) { var secret []byte diff --git a/token/jwt/validation_error.go b/token/jwt/validation_error.go index 05a32432..30737f57 100644 --- a/token/jwt/validation_error.go +++ b/token/jwt/validation_error.go @@ -14,13 +14,15 @@ const ( ValidationErrorSignatureInvalid // Signature validation failed // Standard Claim validation errors - ValidationErrorAudience // AUD validation failed - ValidationErrorExpired // EXP validation failed - ValidationErrorIssuedAt // IAT validation failed - ValidationErrorIssuer // ISS validation failed - ValidationErrorNotValidYet // NBF validation failed - ValidationErrorId // JTI validation failed - ValidationErrorClaimsInvalid // Generic claims validation error + ValidationErrorAudience // AUD validation failed + ValidationErrorExpired // EXP validation failed + ValidationErrorIssuedAt // IAT validation failed + ValidationErrorIssuer // ISS validation failed + ValidationErrorNotValidYet // NBF validation failed + ValidationErrorId // JTI validation failed + ValidationErrorKeyIDInvalid // KeyID invalid error + ValidationErrorAlgorithmInvalid // Algorithm invalid error + ValidationErrorClaimsInvalid // Generic claims validation error ) // The error from Parse if token is not valid From d2ba93be26e13aac11564d24bee202636aa96d0d Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 13 Sep 2024 15:21:52 +1000 Subject: [PATCH 04/33] refactor: misc --- ...orize_request_handler_oidc_request_test.go | 4 +- token/jwt/jwt_strategy.go | 137 ++++++++++++------ token/jwt/jwt_strategy_opts.go | 46 +++--- token/jwt/token.go | 4 +- 4 files changed, 123 insertions(+), 68 deletions(-) diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index af3e9cd3..e6bdf7e2 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -295,7 +295,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'none' but the request object was signed with the 'alg' value 'RS256'.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'RS256' in the header.", }, { name: "ShouldFailRequestURIRS256", @@ -303,7 +303,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, RequestURIs: []string{root.JoinPath("request-object", "valid", "standard.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "standard.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'none' but the request object was signed with the 'alg' value 'RS256'.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'RS256' in the header.", }, { name: "ShouldPassRequestAlgNone", diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index 76174b0b..c6ed2958 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -37,7 +37,7 @@ type DefaultStrategy struct { } func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (tokenString string, signature string, err error) { - o := &optsStrategy{ + o := &StrategyOpts{ claims: MapClaims{}, headers: NewHeaders(), } @@ -87,8 +87,85 @@ func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (toke return encodeNestedCompactEncrypted(ctx, o.claims, o.headers, o.headersJWE, keySig, keyEnc, jose.ContentEncryption(enc)) } +func (j *DefaultStrategy) Validate(ctx context.Context, token *Token, opts ...StrategyOpt) (err error) { + if token == nil { + return errorsx.WithStack(fmt.Errorf("token is nil")) + } + + if token.valid { + return nil + } + + if token.parsedToken == nil { + return errorsx.WithStack(fmt.Errorf("token is in an inconsistent state")) + } + + o := &StrategyOpts{ + sigAlgorithm: SignatureAlgorithms, + keyAlgorithm: EncryptionKeyAlgorithms, + contentEncryption: ContentEncryptionAlgorithms, + jwsKeyFunc: nil, + jweKeyFunc: nil, + } + + for _, opt := range opts { + if err = opt(o); err != nil { + return errorsx.WithStack(err) + } + } + + if err = j.validate(ctx, token.parsedToken, &MapClaims{}, o); err != nil { + return err + } + + token.valid = true + + return nil +} + +func (j *DefaultStrategy) validate(ctx context.Context, t *jwt.JSONWebToken, dest any, o *StrategyOpts) (err error) { + var ( + key *jose.JSONWebKey + kid, alg string + ) + + if kid, alg, err = headerValidateJWS(t.Headers); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + claims := MapClaims{} + + if o.jwsKeyFunc != nil { + if key, err = o.jwsKeyFunc(ctx, t, claims); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if o.client != nil && o.client.IsClientSigned() { + if IsSignedJWTClientSecretAlg(alg) { + if kid != "" { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorKeyIDInvalid, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not support tokens with a kid but the token has kid '%s'", alg, kid)}) + } + + if key, err = NewJWKFromClientSecret(ctx, o.client, "", alg, consts.JSONWebTokenUseSignature); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else { + if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseSignature, true); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseSignature); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + + if err = t.Claims(key.Public(), &dest); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) + } + + return nil +} + func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts ...StrategyOpt) (token *Token, err error) { - o := &optsStrategy{ + o := &StrategyOpts{ sigAlgorithm: SignatureAlgorithms, keyAlgorithm: EncryptionKeyAlgorithms, contentEncryption: ContentEncryptionAlgorithms, @@ -168,59 +245,25 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . token.Claims = claims - var kid, alg string + var alg string - if kid, alg, err = headerValidateJWS(t.Headers); err != nil { + if _, alg, err = headerValidateJWS(t.Headers); err != nil { return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } - if alg != consts.JSONWebTokenAlgNone { - if o.jwsKeyFunc != nil { - if key, err = o.jwsKeyFunc(ctx, t, claims); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - } else if o.client != nil && o.client.IsClientSigned() { - var ( - ckid, calg string - ) - - ckid = o.client.GetSignatureKeyID() - - if calg = o.client.GetSignatureAlg(); calg != "" && calg != alg { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorAlgorithmInvalid, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not match the registered alg '%s'", alg, calg)}) - } - - if IsSignedJWTClientSecretAlg(alg) { - if ckid != "" { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorKeyIDInvalid, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not support tokens with a kid but the token has kid '%s'", alg, kid)}) - } - - if key, err = NewJWKFromClientSecret(ctx, o.client, "", alg, consts.JSONWebTokenUseSignature); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - } else { - if ckid != "" && ckid != kid { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorKeyIDInvalid, Inner: fmt.Errorf("error validating the jws header: kid '%s' does not match the registered kid '%s'", kid, ckid)}) - } - - if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseSignature, true); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - } - } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseSignature); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } + validate := o.client != nil || !o.allowUnverified - if err = t.Claims(key.Public(), &claims); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) + if alg != consts.JSONWebTokenAlgNone && validate { + if err = j.validate(ctx, t, &claims, o); err != nil { + return nil, errorsx.WithStack(err) } } - if err = claims.Valid(); err != nil { - return token, errorsx.WithStack(err) - } + //if err = claims.Valid(); err != nil { + // return token, errorsx.WithStack(err) + //} - token.valid = true + token.valid = validate return token, nil } @@ -234,7 +277,7 @@ func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, op } } - o := &optsStrategy{ + o := &StrategyOpts{ sigAlgorithm: SignatureAlgorithmsNone, keyAlgorithm: EncryptionKeyAlgorithms, contentEncryption: ContentEncryptionAlgorithms, diff --git a/token/jwt/jwt_strategy_opts.go b/token/jwt/jwt_strategy_opts.go index 6291bb50..1828fcc0 100644 --- a/token/jwt/jwt_strategy_opts.go +++ b/token/jwt/jwt_strategy_opts.go @@ -7,7 +7,7 @@ import ( "github.com/go-jose/go-jose/v4/jwt" ) -type optsStrategy struct { +type StrategyOpts struct { client Client claims MapClaims @@ -19,16 +19,26 @@ type optsStrategy struct { jwsKeyFunc KeyFuncJWS jweKeyFunc KeyFuncJWE + + allowUnverified bool } type ( KeyFuncJWS func(ctx context.Context, token *jwt.JSONWebToken, claims MapClaims) (jwk *jose.JSONWebKey, err error) KeyFuncJWE func(ctx context.Context, jwe *jose.JSONWebEncryption, kid, alg string) (jwk *jose.JSONWebKey, err error) - StrategyOpt func(opts *optsStrategy) (err error) + StrategyOpt func(opts *StrategyOpts) (err error) ) +func WithAllowUnverified() StrategyOpt { + return func(opts *StrategyOpts) (err error) { + opts.allowUnverified = true + + return nil + } +} + func WithHeaders(headers Mapper) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { opts.headers = headers return nil @@ -36,7 +46,7 @@ func WithHeaders(headers Mapper) StrategyOpt { } func WithHeadersJWE(headers Mapper) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { opts.headersJWE = headers return nil @@ -44,7 +54,7 @@ func WithHeadersJWE(headers Mapper) StrategyOpt { } func WithClaims(claims MapClaims) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { opts.claims = claims return nil @@ -52,7 +62,7 @@ func WithClaims(claims MapClaims) StrategyOpt { } func WithClient(client Client) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { opts.client = client return nil @@ -60,7 +70,7 @@ func WithClient(client Client) StrategyOpt { } func WithIDTokenClient(client any) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { switch c := client.(type) { case IDTokenClient: opts.client = &decoratedIDTokenClient{IDTokenClient: c} @@ -71,7 +81,7 @@ func WithIDTokenClient(client any) StrategyOpt { } func WithUserInfoClient(client any) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { switch c := client.(type) { case UserInfoClient: opts.client = &decoratedUserInfoClient{UserInfoClient: c} @@ -82,7 +92,7 @@ func WithUserInfoClient(client any) StrategyOpt { } func WithIntrospectionClient(client any) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { switch c := client.(type) { case IntrospectionClient: opts.client = &decoratedIntrospectionClient{IntrospectionClient: c} @@ -93,7 +103,7 @@ func WithIntrospectionClient(client any) StrategyOpt { } func WithJARMClient(client any) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { switch c := client.(type) { case JARMClient: opts.client = &decoratedJARMClient{JARMClient: c} @@ -104,7 +114,7 @@ func WithJARMClient(client any) StrategyOpt { } func WithJARClient(client any) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { switch c := client.(type) { case JARClient: opts.client = &decoratedJARClient{JARClient: c} @@ -115,7 +125,7 @@ func WithJARClient(client any) StrategyOpt { } func WithJWTProfileAccessTokenClient(client any) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { switch c := client.(type) { case JWTProfileAccessTokenClient: opts.client = &decoratedJWTProfileAccessTokenClient{JWTProfileAccessTokenClient: c} @@ -126,7 +136,7 @@ func WithJWTProfileAccessTokenClient(client any) StrategyOpt { } func WithNewStatelessJWTProfileIntrospectionClient(client any) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { switch c := client.(type) { case IntrospectionClient: opts.client = &decoratedIntrospectionClient{IntrospectionClient: c} @@ -139,7 +149,7 @@ func WithNewStatelessJWTProfileIntrospectionClient(client any) StrategyOpt { } func WithSigAlgorithm(algs ...jose.SignatureAlgorithm) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { opts.sigAlgorithm = algs return nil @@ -147,7 +157,7 @@ func WithSigAlgorithm(algs ...jose.SignatureAlgorithm) StrategyOpt { } func WithKeyAlgorithm(algs ...jose.KeyAlgorithm) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { opts.keyAlgorithm = algs return nil @@ -155,7 +165,7 @@ func WithKeyAlgorithm(algs ...jose.KeyAlgorithm) StrategyOpt { } func WithContentEncryption(enc ...jose.ContentEncryption) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { opts.contentEncryption = enc return nil @@ -163,7 +173,7 @@ func WithContentEncryption(enc ...jose.ContentEncryption) StrategyOpt { } func WithKeyFunc(f KeyFuncJWS) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { opts.jwsKeyFunc = f return nil @@ -171,7 +181,7 @@ func WithKeyFunc(f KeyFuncJWS) StrategyOpt { } func WithKeyFuncJWE(f KeyFuncJWE) StrategyOpt { - return func(opts *optsStrategy) (err error) { + return func(opts *StrategyOpts) (err error) { opts.jweKeyFunc = f return nil diff --git a/token/jwt/token.go b/token/jwt/token.go index 7616e10b..65c69498 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -145,6 +145,8 @@ type Token struct { Claims MapClaims + parsedToken *jwt.JSONWebToken + valid bool } @@ -394,7 +396,7 @@ func unsignedToken(token *Token) (tokenString string, err error) { } func newToken(parsedToken *jwt.JSONWebToken, claims MapClaims) (*Token, error) { - token := &Token{Claims: claims} + token := &Token{Claims: claims, parsedToken: parsedToken} if len(parsedToken.Headers) != 1 { return nil, &ValidationError{text: fmt.Sprintf("only one header supported, got %v", len(parsedToken.Headers)), Errors: ValidationErrorMalformed} } From eb609a6738b0137ebd50528611514450c48a3126 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Thu, 19 Sep 2024 16:29:57 +1000 Subject: [PATCH 05/33] refactor: misc --- client.go | 28 ++- client_authentication_secret.go | 3 +- token/jwt/claims_map.go | 339 +++++++++++++++++++++++++------- token/jwt/client.go | 13 +- token/jwt/client_test.go | 16 +- token/jwt/jwt_strategy.go | 8 +- token/jwt/util.go | 19 +- token/jwt/validation_error.go | 5 +- 8 files changed, 339 insertions(+), 92 deletions(-) diff --git a/client.go b/client.go index 34bd7d69..1d6bc4e0 100644 --- a/client.go +++ b/client.go @@ -5,7 +5,6 @@ package oauth2 import ( "context" - "fmt" "time" "github.com/go-jose/go-jose/v4" @@ -21,8 +20,17 @@ type Client interface { // GetClientSecret returns the ClientSecret. GetClientSecret() (secret ClientSecret) - // GetClientSecretPlainText returns the ClientSecret as plaintext if available. - GetClientSecretPlainText() (secret []byte, err error) + // GetClientSecretPlainText returns the ClientSecret as plaintext if available. The semantics of this function + // return values are important. + // If the client is not configured with a secret the return should be: + // - secret with value nil, ok with value false, and err with value of nil + // If the client is configured with a secret but is hashed or otherwise not a plaintext value: + // - secret with value nil, ok with value true, and err with value of nil + // If an error occurs retrieving the secret other than this: + // - secret with value nil, ok with value true, and err with value of the error + // If the plaintext secret is successful: + // - secret with value of the bytes of the plaintext secret, ok with value true, and err with value of nil + GetClientSecretPlainText() (secret []byte, ok bool, err error) // GetRedirectURIs returns the client's allowed redirect URIs. GetRedirectURIs() []string @@ -475,12 +483,20 @@ func (c *DefaultClient) GetClientSecret() (secret ClientSecret) { return c.ClientSecret } -func (c *DefaultClient) GetClientSecretPlainText() (secret []byte, err error) { +func (c *DefaultClient) GetClientSecretPlainText() (secret []byte, ok bool, err error) { if c.ClientSecret == nil || !c.ClientSecret.Valid() { - return nil, fmt.Errorf("this secret doesn't support plaintext") + return nil, false, nil } - return c.ClientSecret.GetPlainTextValue() + if !c.ClientSecret.IsPlainText() { + return nil, true, nil + } + + if secret, err = c.ClientSecret.GetPlainTextValue(); err != nil { + return nil, true, err + } + + return secret, true, nil } func (c *DefaultClient) GetRotatedClientSecrets() (secrets []ClientSecret) { diff --git a/client_authentication_secret.go b/client_authentication_secret.go index bcb33845..e84e4ceb 100644 --- a/client_authentication_secret.go +++ b/client_authentication_secret.go @@ -19,7 +19,8 @@ type ClientSecret interface { IsPlainText() (is bool) // GetPlainTextValue is a utility function to return the secret in the plaintext format making it usable for the - // client_secret_jwt authentication method. + // client_secret_jwt authentication method. If the client secret doesn't have a value that is plaintext it should + // return oauth2.ErrClientSecretNotPlainText for the sake of deterministic error values. GetPlainTextValue() (secret []byte, err error) // Valid should return false if the secret is nil or otherwise invalid. diff --git a/token/jwt/claims_map.go b/token/jwt/claims_map.go index b5b6b063..9c8aec2a 100644 --- a/token/jwt/claims_map.go +++ b/token/jwt/claims_map.go @@ -25,27 +25,97 @@ var TimeFunc = time.Now // This is the default claims type if you don't supply one type MapClaims map[string]any +// GetIssuer returns the iss claim. +func (m MapClaims) GetIssuer() (iss string, ok bool) { + var v any + + if v, ok = m[consts.ClaimIssuer]; !ok { + return "", false + } + + iss, ok = v.(string) + + return iss, ok +} + // VerifyIssuer compares the iss claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { - iss, _ := m[consts.ClaimIssuer].(string) +func (m MapClaims) VerifyIssuer(cmp string, required bool) (ok bool) { + var iss string + + if iss, ok = m.GetIssuer(); !ok { + return !required + } + + return verifyMapString(iss, cmp, required) +} + +// GetSubject returns the sub claim. +func (m MapClaims) GetSubject() (sub string, ok bool) { + var v any + + if v, ok = m[consts.ClaimSubject]; !ok { + return "", false + } + + sub, ok = v.(string) + + return sub, ok +} + +// VerifySubject compares the iss claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifySubject(cmp string, required bool) (ok bool) { + var sub string - return verifyIss(iss, cmp, req) + if sub, ok = m.GetSubject(); !ok { + return !required + } + + return verifyMapString(sub, cmp, required) +} + +// GetAudience returns the aud claim. +func (m MapClaims) GetAudience() (aud []string, ok bool) { + return StringSliceFromMap(m[consts.ClaimAudience]) } // VerifyAudience compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyAudience(cmp string, req bool) bool { - var ( - aud []string - ok bool - ) +func (m MapClaims) VerifyAudience(cmp string, required bool) (ok bool) { + var aud []string - if aud, ok = StringSliceFromMap(m[consts.ClaimAudience]); ok { - return verifyAud(aud, cmp, req) + if aud, ok = m.GetAudience(); !ok { + return !required } - return false + return verifyAud(aud, cmp, required) +} + +// VerifyAudienceAll compares the aud claim against cmp. +// If required is false, this method will return true if the value matches or is unset. +// This variant requires all of the audience values in the cmp. +func (m MapClaims) VerifyAudienceAll(cmp []string, required bool) (ok bool) { + var aud []string + + if aud, ok = m.GetAudience(); !ok { + return !required + } + + return verifyAudAll(aud, cmp, required) +} + +// VerifyAudienceAny compares the aud claim against cmp. +// If required is false, this method will return true if the value matches or is unset. +// This variant requires any of the audience values in the cmp. +func (m MapClaims) VerifyAudienceAny(cmp []string, required bool) (ok bool) { + var aud []string + + if aud, ok = m.GetAudience(); !ok { + return !required + } + + return verifyAudAny(aud, cmp, required) } // GetExpiresAt returns the exp claim. @@ -55,90 +125,114 @@ func (m MapClaims) GetExpiresAt() (exp int64, ok bool) { // VerifyExpiresAt compares the exp claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { - if v, ok := m.GetExpiresAt(); ok { - return verifyExp(v, cmp, req) +func (m MapClaims) VerifyExpiresAt(cmp int64, required bool) (ok bool) { + var exp int64 + + if exp, ok = m.GetExpiresAt(); !ok { + return !required } - return !req + return verifyExp(exp, cmp, required) } // GetIssuedAt returns the iat claim. -func (m MapClaims) GetIssuedAt() (exp int64, ok bool) { +func (m MapClaims) GetIssuedAt() (iat int64, ok bool) { return m.toInt64(consts.ClaimIssuedAt) } // VerifyIssuedAt compares the iat claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { - if v, ok := m.GetIssuedAt(); ok { - return verifyIat(v, cmp, req) +func (m MapClaims) VerifyIssuedAt(cmp int64, required bool) (ok bool) { + var iat int64 + + if iat, ok = m.GetIssuedAt(); !ok { + return !required } - return !req + return verifyInt64Past(iat, cmp, required) } // GetNotBefore returns the nbf claim. -func (m MapClaims) GetNotBefore() (exp int64, ok bool) { +func (m MapClaims) GetNotBefore() (nbf int64, ok bool) { return m.toInt64(consts.ClaimNotBefore) } // VerifyNotBefore compares the nbf claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { - if v, ok := m.GetNotBefore(); ok { - return verifyNbf(v, cmp, req) - } +func (m MapClaims) VerifyNotBefore(cmp int64, required bool) (ok bool) { + var nbf int64 - return !req -} - -func (m MapClaims) toInt64(claim string) (int64, bool) { - switch t := m[claim].(type) { - case float64: - return int64(t), true - case int64: - return t, true - case json.Number: - v, err := t.Int64() - if err == nil { - return v, true - } - - vf, err := t.Float64() - if err != nil { - return 0, false - } - - return int64(vf), true + if nbf, ok = m.GetNotBefore(); !ok { + return !required } - return 0, false + return verifyInt64Past(nbf, cmp, required) } // Valid validates time based claims "exp, iat, nbf". // There is no accounting for clock skew. // As well, if any of the above claims are not in the token, it will still // be considered a valid claim. -func (m MapClaims) Valid() error { +func (m MapClaims) Valid(opts ...ValidationOpt) error { + vopts := &optsValidation{} + + for _, opt := range opts { + opt(vopts) + } + + var now int64 + + if vopts.timef != nil { + now = vopts.timef().Unix() + } else { + now = TimeFunc().Unix() + } + vErr := new(ValidationError) - now := TimeFunc().Unix() - if !m.VerifyExpiresAt(now, false) { + if !m.VerifyExpiresAt(now, vopts.expRequired) { vErr.Inner = errors.New("Token is expired") vErr.Errors |= ValidationErrorExpired } - if !m.VerifyIssuedAt(now, false) { + if !m.VerifyIssuedAt(now, vopts.iatRequired) { vErr.Inner = errors.New("Token used before issued") vErr.Errors |= ValidationErrorIssuedAt } - if !m.VerifyNotBefore(now, false) { + if !m.VerifyNotBefore(now, vopts.nbfRequired) { vErr.Inner = errors.New("Token is not valid yet") vErr.Errors |= ValidationErrorNotValidYet } + if len(vopts.iss) != 0 { + if !m.VerifyIssuer(vopts.iss, true) { + vErr.Inner = errors.New("Token has invalid issuer") + vErr.Errors |= ValidationErrorIssuer + } + } + + if len(vopts.sub) != 0 { + if !m.VerifySubject(vopts.sub, true) { + vErr.Inner = errors.New("Token has invalid subject") + vErr.Errors |= ValidationErrorSubject + } + } + + if len(vopts.aud) != 0 { + if !m.VerifyAudienceAny(vopts.aud, true) { + vErr.Inner = errors.New("Token has invalid audience") + vErr.Errors |= ValidationErrorAudience + } + } + + if len(vopts.audAll) != 0 { + if !m.VerifyAudienceAll(vopts.audAll, true) { + vErr.Inner = errors.New("Token has invalid audience") + vErr.Errors |= ValidationErrorAudience + } + } + if vErr.valid() { return nil } @@ -162,13 +256,99 @@ func (m MapClaims) UnmarshalJSON(b []byte) error { return nil } +func (m MapClaims) toInt64(claim string) (val int64, ok bool) { + var err error + + switch t := m[claim].(type) { + case float64: + return int64(t), true + case int64: + return t, true + case json.Number: + if val, err = t.Int64(); err == nil { + return val, true + } + + var valf float64 + + if valf, err = t.Float64(); err != nil { + return 0, false + } + + return int64(valf), true + } + + return 0, false +} + +type ValidationOpt func(opts *optsValidation) + +type optsValidation struct { + timef func() time.Time + iss string + aud []string + audAll []string + sub string + expRequired bool + iatRequired bool + nbfRequired bool +} + +func ValidateTimeFunc(timef func() time.Time) ValidationOpt { + return func(opts *optsValidation) { + opts.timef = timef + } +} + +func ValidateIssuer(iss string) ValidationOpt { + return func(opts *optsValidation) { + opts.iss = iss + } +} + +func ValidateAudienceAny(aud ...string) ValidationOpt { + return func(opts *optsValidation) { + opts.aud = aud + } +} + +func ValidateAudienceAll(aud ...string) ValidationOpt { + return func(opts *optsValidation) { + opts.audAll = aud + } +} + +func ValidateSubject(sub string) ValidationOpt { + return func(opts *optsValidation) { + opts.sub = sub + } +} + +func ValidateRequireExpiresAt() ValidationOpt { + return func(opts *optsValidation) { + opts.expRequired = true + } +} + +func ValidateRequireIssuedAt() ValidationOpt { + return func(opts *optsValidation) { + opts.iatRequired = true + } +} + +func ValidateRequireNotBefore() ValidationOpt { + return func(opts *optsValidation) { + opts.nbfRequired = true + } +} + func verifyAud(aud []string, cmp string, required bool) bool { if len(aud) == 0 { return !required } for _, a := range aud { - if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { + if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) == 1 { return true } } @@ -176,38 +356,61 @@ func verifyAud(aud []string, cmp string, required bool) bool { return false } -func verifyExp(exp int64, now int64, required bool) bool { - if exp == 0 { +func verifyAudAny(aud []string, cmp []string, required bool) bool { + if len(aud) == 0 { return !required } - return now <= exp + for _, c := range cmp { + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(c)) == 1 { + return true + } + } + } + + return false } -func verifyIat(iat int64, now int64, required bool) bool { - if iat == 0 { +func verifyAudAll(aud []string, cmp []string, required bool) bool { + if len(aud) == 0 { return !required } - return now >= iat +outer: + for _, c := range cmp { + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(c)) == 1 { + continue outer + } + } + + return false + } + + return true } -func verifyIss(iss string, cmp string, required bool) bool { - if iss == "" { +func verifyExp(exp int64, now int64, required bool) bool { + if exp == 0 { return !required } - if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 { - return true - } else { - return false + return now <= exp +} + +func verifyInt64Past(iat int64, now int64, required bool) bool { + if iat == 0 { + return !required } + + return now >= iat } -func verifyNbf(nbf int64, now int64, required bool) bool { - if nbf == 0 { +func verifyMapString(iss string, cmp string, required bool) bool { + if iss == "" { return !required } - return now >= nbf + return subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) == 1 } diff --git a/token/jwt/client.go b/token/jwt/client.go index 5f723492..3008117a 100644 --- a/token/jwt/client.go +++ b/token/jwt/client.go @@ -82,8 +82,17 @@ type Client interface { } type BaseClient interface { - // GetClientSecretPlainText returns the ClientSecret as plaintext if available. - GetClientSecretPlainText() (secret []byte, err error) + // GetClientSecretPlainText returns the ClientSecret as plaintext if available. The semantics of this function + // return values are important. + // If the client is not configured with a secret the return should be: + // - secret with value nil, ok with value false, and err with value of nil + // If the client is configured with a secret but is hashed or otherwise not a plaintext value: + // - secret with value nil, ok with value true, and err with value of nil + // If an error occurs retrieving the secret other than this: + // - secret with value nil, ok with value true, and err with value of the error + // If the plaintext secret is successful: + // - secret with value of the bytes of the plaintext secret, ok with value true, and err with value of nil + GetClientSecretPlainText() (secret []byte, ok bool, err error) // GetJSONWebKeys returns the JSON Web Key Set containing the public key used by the client to authenticate. GetJSONWebKeys() (jwks *jose.JSONWebKeySet) diff --git a/token/jwt/client_test.go b/token/jwt/client_test.go index 4ae0be76..29d3f7c6 100644 --- a/token/jwt/client_test.go +++ b/token/jwt/client_test.go @@ -8,6 +8,8 @@ import ( type testClient struct { secret []byte + secretNotPlainText bool + secretNotDefined bool kid, alg string encKID, encAlg, enc string csigned bool @@ -15,12 +17,20 @@ type testClient struct { jwksURI string } -func (r *testClient) GetClientSecretPlainText() (secret []byte, err error) { +func (r *testClient) GetClientSecretPlainText() (secret []byte, ok bool, err error) { + if r.secretNotDefined { + return nil, false, nil + } + + if r.secretNotPlainText { + return nil, true, nil + } + if r.secret != nil { - return r.secret, nil + return r.secret, true, nil } - return nil, fmt.Errorf("not supported") + return nil, true, fmt.Errorf("not supported") } func (r *testClient) GetSignatureKeyID() (kid string) { diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index c6ed2958..adadb29d 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -78,10 +78,10 @@ func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (toke if IsEncryptedJWTClientSecretAlg(alg) { if keyEnc, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { - return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving issuer jwk: error occurred retrieving the client secret: %w", err)) + return "", "", errorsx.WithStack(fmt.Errorf("Failed to encrypt the JWT using the client secret. %w", err)) } } else if keyEnc, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseEncryption, false); err != nil { - return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving client jwk: %w", err)) + return "", "", errorsx.WithStack(fmt.Errorf("Failed to encrypt the JWT using the client configuration. %w", err)) } return encodeNestedCompactEncrypted(ctx, o.claims, o.headers, o.headersJWE, keySig, keyEnc, jose.ContentEncryption(enc)) @@ -259,10 +259,6 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . } } - //if err = claims.Valid(); err != nil { - // return token, errorsx.WithStack(err) - //} - token.valid = validate return token, nil diff --git a/token/jwt/util.go b/token/jwt/util.go index c8030d8c..81cca2fa 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -189,7 +189,7 @@ func (e *JWKLookupError) GetDescription() string { } func (e *JWKLookupError) Error() string { - return fmt.Sprintf("Error occurred looking up JSON Web Key: %s", e.Description) + return fmt.Sprintf("Error occurred retriving the JSON Web Key. %s", e.Description) } // FindClientPublicJWK given a BaseClient, JWKSFetcherStrategy, and search parameters will return a *jose.JSONWebKey on @@ -270,10 +270,21 @@ func SearchJWKS(jwks *jose.JSONWebKeySet, kid, alg, use string, strict bool) (ke // NewJWKFromClientSecret returns a JWK from a client secret. func NewJWKFromClientSecret(ctx context.Context, client BaseClient, kid, alg, use string) (jwk *jose.JSONWebKey, err error) { - var secret []byte + var ( + secret []byte + ok bool + ) + + if secret, ok, err = client.GetClientSecretPlainText(); err != nil { + return nil, &JWKLookupError{Description: fmt.Sprintf("The client returned an error while trying to retrieve the plaintext client secret: %s.", err.Error())} + } + + if !ok { + return nil, &JWKLookupError{Description: "The client is not configured with a client secret."} + } - if secret, err = client.GetClientSecretPlainText(); err != nil { - return nil, err + if len(secret) == 0 { + return nil, &JWKLookupError{Description: "The client is not configured with a client secret that can be used for symmetric algorithms."} } return &jose.JSONWebKey{ diff --git a/token/jwt/validation_error.go b/token/jwt/validation_error.go index 30737f57..124f7903 100644 --- a/token/jwt/validation_error.go +++ b/token/jwt/validation_error.go @@ -14,12 +14,13 @@ const ( ValidationErrorSignatureInvalid // Signature validation failed // Standard Claim validation errors + ValidationErrorId // JTI validation failed ValidationErrorAudience // AUD validation failed ValidationErrorExpired // EXP validation failed ValidationErrorIssuedAt // IAT validation failed - ValidationErrorIssuer // ISS validation failed ValidationErrorNotValidYet // NBF validation failed - ValidationErrorId // JTI validation failed + ValidationErrorIssuer // ISS validation failed + ValidationErrorSubject // ISS validation failed ValidationErrorKeyIDInvalid // KeyID invalid error ValidationErrorAlgorithmInvalid // Algorithm invalid error ValidationErrorClaimsInvalid // Generic claims validation error From e00c898efaf76721a8ef8a8f407874415c5b123e Mon Sep 17 00:00:00 2001 From: James Elliott Date: Thu, 19 Sep 2024 22:42:15 +1000 Subject: [PATCH 06/33] temp --- client_authentication_strategy.go | 24 +-- token/jwt/token.go | 9 +- token/jwt/validation_error.go | 19 +-- token/jwt/validator.go | 241 ++++++++++++++++++++++++++++++ 4 files changed, 271 insertions(+), 22 deletions(-) create mode 100644 token/jwt/validator.go diff --git a/client_authentication_strategy.go b/client_authentication_strategy.go index 9f63f7d3..2b8541ea 100644 --- a/client_authentication_strategy.go +++ b/client_authentication_strategy.go @@ -134,7 +134,8 @@ func (s *DefaultClientAuthenticationStrategy) authenticate(ctx context.Context, return client, method, nil } -func NewClientAssertion(ctx context.Context, store ClientManager, raw, assertionType string, resolver EndpointClientAuthHandler) (assertion *ClientAssertion, err error) { +// NewClientAssertion converts a raw assertion string into a *ClientAssertion. +func NewClientAssertion(ctx context.Context, store ClientManager, assertion, assertionType string, resolver EndpointClientAuthHandler) (a *ClientAssertion, err error) { var ( token *xjwt.Token @@ -144,25 +145,25 @@ func NewClientAssertion(ctx context.Context, store ClientManager, raw, assertion switch assertionType { case consts.ClientAssertionTypeJWTBearer: - if len(raw) == 0 { - return &ClientAssertion{Raw: raw, Type: assertionType}, errorsx.WithStack(ErrInvalidRequest.WithHintf("The request parameter 'client_assertion' must be set when using 'client_assertion_type' of '%s'.", consts.ClientAssertionTypeJWTBearer)) + if len(assertion) == 0 { + return &ClientAssertion{Assertion: assertion, Type: assertionType}, errorsx.WithStack(ErrInvalidRequest.WithHintf("The request parameter 'client_assertion' must be set when using 'client_assertion_type' of '%s'.", consts.ClientAssertionTypeJWTBearer)) } default: - return &ClientAssertion{Raw: raw, Type: assertionType}, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unknown client_assertion_type '%s'.", assertionType)) + return &ClientAssertion{Assertion: assertion, Type: assertionType}, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unknown client_assertion_type '%s'.", assertionType)) } - if token, _, err = xjwt.NewParser(xjwt.WithoutClaimsValidation()).ParseUnverified(raw, &xjwt.MapClaims{}); err != nil { - return &ClientAssertion{Raw: raw, Type: assertionType}, resolveJWTErrorToRFCError(err) + if token, _, err = xjwt.NewParser(xjwt.WithoutClaimsValidation()).ParseUnverified(assertion, &xjwt.MapClaims{}); err != nil { + return &ClientAssertion{Assertion: assertion, Type: assertionType}, resolveJWTErrorToRFCError(err) } if id, err = token.Claims.GetSubject(); err != nil { if id, err = token.Claims.GetIssuer(); err != nil { - return &ClientAssertion{Raw: raw, Type: assertionType}, nil + return &ClientAssertion{Assertion: assertion, Type: assertionType}, nil } } if client, err = store.GetClient(ctx, id); err != nil { - return &ClientAssertion{Raw: raw, Type: assertionType, ID: id}, nil + return &ClientAssertion{Assertion: assertion, Type: assertionType, ID: id}, nil } if c, ok := client.(AuthenticationMethodClient); ok { @@ -170,7 +171,7 @@ func NewClientAssertion(ctx context.Context, store ClientManager, raw, assertion } return &ClientAssertion{ - Raw: raw, + Assertion: assertion, Type: assertionType, Parsed: true, ID: id, @@ -180,8 +181,9 @@ func NewClientAssertion(ctx context.Context, store ClientManager, raw, assertion }, nil } +// ClientAssertion represents a client assertion. type ClientAssertion struct { - Raw, Type string + Assertion, Type string Parsed bool ID, Method, Algorithm string Client Client @@ -295,7 +297,7 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssert claims = &xjwt.RegisteredClaims{} - if token, err = parser.ParseWithClaims(assertion.Raw, claims, func(token *xjwt.Token) (key any, err error) { + if token, err = parser.ParseWithClaims(assertion.Assertion, claims, func(token *xjwt.Token) (key any, err error) { if subtle.ConstantTimeCompare([]byte(client.GetID()), []byte(claims.Subject)) == 0 { return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The supplied 'client_id' did not match the 'sub' claim of the 'client_assertion'.")) } diff --git a/token/jwt/token.go b/token/jwt/token.go index 65c69498..2f926c1e 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -343,6 +343,10 @@ func (t *Token) CompactSignedString(k any) (tokenString string, err error) { return tokenString, nil } +func (t *Token) Validate() { + +} + // IsJWTProfileAccessToken returns true if the token is a JWT Profile Access Token. func (t *Token) IsJWTProfileAccessToken() (ok bool) { var ( @@ -406,6 +410,9 @@ func newToken(parsedToken *jwt.JSONWebToken, claims MapClaims) (*Token, error) { token.Header = map[string]any{ consts.JSONWebTokenHeaderAlgorithm: h.Algorithm, } + + token.SignatureAlgorithm = jose.SignatureAlgorithm(h.Algorithm) + if h.KeyID != "" { token.Header[consts.JSONWebTokenHeaderKeyIdentifier] = h.KeyID token.KeyID = h.KeyID @@ -415,8 +422,6 @@ func newToken(parsedToken *jwt.JSONWebToken, claims MapClaims) (*Token, error) { token.Header[string(k)] = v } - token.SignatureAlgorithm = jose.SignatureAlgorithm(h.Algorithm) - return token, nil } diff --git a/token/jwt/validation_error.go b/token/jwt/validation_error.go index 124f7903..73473e54 100644 --- a/token/jwt/validation_error.go +++ b/token/jwt/validation_error.go @@ -14,15 +14,16 @@ const ( ValidationErrorSignatureInvalid // Signature validation failed // Standard Claim validation errors - ValidationErrorId // JTI validation failed - ValidationErrorAudience // AUD validation failed - ValidationErrorExpired // EXP validation failed - ValidationErrorIssuedAt // IAT validation failed - ValidationErrorNotValidYet // NBF validation failed - ValidationErrorIssuer // ISS validation failed - ValidationErrorSubject // ISS validation failed - ValidationErrorKeyIDInvalid // KeyID invalid error - ValidationErrorAlgorithmInvalid // Algorithm invalid error + ValidationErrorId // Claim JTI validation failed + ValidationErrorAudience // Claim AUD validation failed + ValidationErrorExpired // Claim EXP validation failed + ValidationErrorIssuedAt // Claim IAT validation failed + ValidationErrorNotValidYet // Claim NBF validation failed + ValidationErrorIssuer // Claim ISS validation failed + ValidationErrorSubject // Claim SUB validation failed + ValidationErrorTypInvalid // Header TYP invalid error + ValidationErrorKeyIDInvalid // Header KID invalid error + ValidationErrorAlgorithmInvalid // Header ALG invalid error ValidationErrorClaimsInvalid // Generic claims validation error ) diff --git a/token/jwt/validator.go b/token/jwt/validator.go new file mode 100644 index 00000000..4e7f61fa --- /dev/null +++ b/token/jwt/validator.go @@ -0,0 +1,241 @@ +package jwt + +import ( + "errors" + + "authelia.com/provider/oauth2/internal/consts" +) + +func NewValidator(opts ...ValidatorOpt) (validator *Validator) { + validator = &Validator{ + types: []string{consts.JSONWebTokenTypeJWT}, + nbf: -1, + exp: -1, + iat: -1, + } + + for _, opt := range opts { + opt(validator) + } + + return validator +} + +type ValidatorOpt func(*Validator) + +func ValidateTypes(types ...string) ValidatorOpt { + return func(validator *Validator) { + validator.types = types + } +} + +func ValidateAlgorithm(alg string) ValidatorOpt { + return func(validator *Validator) { + validator.alg = alg + } +} + +func ValidateKeyID(kid string) ValidatorOpt { + return func(validator *Validator) { + validator.kid = kid + } +} + +func ValidateIssuer(iss string) ValidatorOpt { + return func(validator *Validator) { + validator.iss = iss + } +} + +func ValidateSubject(sub string) ValidatorOpt { + return func(validator *Validator) { + validator.sub = sub + } +} + +func ValidateAudienceAll(aud []string) ValidatorOpt { + return func(validator *Validator) { + validator.audAll = aud + } +} + +func ValidateAudienceAny(aud []string) ValidatorOpt { + return func(validator *Validator) { + validator.audAny = aud + } +} + +func ValidateNotBefore(nbf int64) ValidatorOpt { + return func(validator *Validator) { + validator.nbf = nbf + } +} + +func ValidateRequireNotBefore() ValidatorOpt { + return func(validator *Validator) { + validator.requireNBF = true + } +} + +func ValidateExpires(exp int64) ValidatorOpt { + return func(validator *Validator) { + validator.exp = exp + } +} + +func ValidateRequireExpires() ValidatorOpt { + return func(validator *Validator) { + validator.requireEXP = true + } +} + +func ValidateIssuedAt(iat int64) ValidatorOpt { + return func(validator *Validator) { + validator.iat = iat + } +} + +func ValidateRequireIssuedAt() ValidatorOpt { + return func(validator *Validator) { + validator.requireIAT = true + } +} + +type Validator struct { + types []string + alg string + kid string + iss string + sub string + audAll []string + audAny []string + nbf int64 + requireNBF bool + exp int64 + requireEXP bool + iat int64 + requireIAT bool +} + +func (v Validator) Validate(token *Token) (err error) { + vErr := new(ValidationError) + now := TimeFunc().Unix() + + if len(v.types) != 0 { + if !validateTokenType(v.types, token.Header) { + vErr.Inner = errors.New("token has an invalid typ") + vErr.Errors |= ValidationErrorTypInvalid + } + } + + if len(v.alg) != 0 { + if v.alg != string(token.SignatureAlgorithm) { + vErr.Inner = errors.New("token has an invalid alg") + vErr.Errors |= ValidationErrorAlgorithmInvalid + } + } + + if len(v.kid) != 0 { + if v.kid != token.KeyID { + vErr.Inner = errors.New("token has an invalid kid") + vErr.Errors |= ValidationErrorKeyIDInvalid + } + } + + if len(v.iss) != 0 { + if !token.Claims.VerifyIssuer(v.iss, true) { + vErr.Inner = errors.New("token has an invalid issuer") + vErr.Errors |= ValidationErrorIssuer + } + } + + if len(v.sub) != 0 { + if !token.Claims.VerifySubject(v.sub, true) { + vErr.Inner = errors.New("token has an invalid subject") + vErr.Errors |= ValidationErrorSubject + } + } + + if len(v.audAll) != 0 { + if !token.Claims.VerifyAudienceAll(v.audAll, true) { + vErr.Inner = errors.New("token has an invalid audience") + vErr.Errors |= ValidationErrorAudience + } + } + + if len(v.audAny) != 0 { + if !token.Claims.VerifyAudienceAny(v.audAny, true) { + vErr.Inner = errors.New("token has an invalid audience") + vErr.Errors |= ValidationErrorAudience + } + } + + if v.exp != -1 { + exp := v.exp + + if exp == 0 { + exp = now + } + + if !token.Claims.VerifyExpiresAt(exp, v.requireEXP) { + vErr.Inner = errors.New("token is expired") + vErr.Errors |= ValidationErrorExpired + } + } + + if v.iat != -1 { + iat := v.iat + + if iat == 0 { + iat = now + } + + if !token.Claims.VerifyIssuedAt(iat, v.requireIAT) { + vErr.Inner = errors.New("token used before issued") + vErr.Errors |= ValidationErrorIssuedAt + } + } + + if v.nbf != -1 { + nbf := v.nbf + + if nbf == 0 { + nbf = now + } + + if !token.Claims.VerifyNotBefore(nbf, v.requireNBF) { + vErr.Inner = errors.New("token is not valid yet") + vErr.Errors |= ValidationErrorNotValidYet + } + } + + if vErr.valid() { + return nil + } + + return vErr +} + +func validateTokenType(typValues []string, header map[string]any) bool { + var ( + typ string + raw any + ok bool + ) + + if raw, ok = header[consts.JSONWebTokenHeaderType]; !ok { + return false + } + + if typ, ok = raw.(string); !ok { + return false + } + + for _, t := range typValues { + if t == typ { + return true + } + } + + return false +} From 38f8bbfa194d06149f708eb7ee26d1186ed963df Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sun, 22 Sep 2024 12:02:53 +1000 Subject: [PATCH 07/33] temp --- authorize_request_handler.go | 162 ++++++---- ...orize_request_handler_oidc_request_test.go | 22 +- client_authentication_secret_plaintext.go | 41 +++ token/jwt/claims_map.go | 47 ++- token/jwt/jwt_strategy.go | 286 ++++++++++-------- token/jwt/token.go | 92 +++++- token/jwt/token_test.go | 2 +- token/jwt/util.go | 42 ++- token/jwt/validation_error.go | 35 +-- token/jwt/validator.go | 60 +--- 10 files changed, 484 insertions(+), 305 deletions(-) create mode 100644 client_authentication_secret_plaintext.go diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 7aea7bc2..92ee5245 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strings" + "time" "github.com/pkg/errors" @@ -65,13 +66,13 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co } var ( - alg string - algAny, algNone bool + alg string + algAny bool ) switch alg = client.GetRequestObjectSigningAlg(); alg { case consts.JSONWebTokenAlgNone: - algNone = true + break case "": algAny = true default: @@ -115,15 +116,26 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co assertion = request.Form.Get(consts.FormParameterRequest) } + issuer := f.Config.GetIDTokenIssuer(ctx) + strategy := f.Config.GetJWTStrategy(ctx) token, err := strategy.Decode(ctx, assertion, jwt.WithSigAlgorithm(jwt.SignatureAlgorithmsNone...), jwt.WithJARClient(client)) if err != nil { - return errorsx.WithStack(wrapRequestObjectDecodeError(token, client, openid, err)) + return errorsx.WithStack(fmtRequestObjectDecodeError(token, client, issuer, openid, err)) + } + + optsValidHeader := []jwt.TokenValidationOption{ + jwt.ValidateKeyID(client.GetRequestObjectSigningKeyID()), + jwt.ValidateAlgorithm(client.GetRequestObjectSigningAlg()), + } + + if err = token.Valid(optsValidHeader...); err != nil { + return errorsx.WithStack(fmtRequestObjectDecodeError(token, client, issuer, openid, err)) } if algAny { - if token.SignatureAlgorithm == "none" { + if token.SignatureAlgorithm == consts.JSONWebTokenAlgNone { return errorsx.WithStack( ErrInvalidRequestObject. WithHintf("%s client provided a request object that has an invalid 'kid' or 'alg' header value.", hintRequestObjectPrefix(openid)). @@ -189,58 +201,82 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co } } - if !algNone { - issuer := f.Config.GetIDTokenIssuer(ctx) + if len(issuer) == 0 { + return errorsx.WithStack(ErrServerError.WithHintf("%s request could not be processed due to an authorization server configuration issue.", hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object that was signed but the issuer for this authorization server is not known.", request.GetClient().GetID())) + } - if len(issuer) == 0 { - return errorsx.WithStack(ErrServerError.WithHintf("%s request could not be processed due to an authorization server configuration issue.", hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object that was signed but the issuer for this authorization server is not known.", request.GetClient().GetID())) - } + optsValidClaims := []jwt.ClaimValidationOption{ + jwt.ValidateTimeFunc(func() time.Time { + return time.Now().UTC() + }), + jwt.ValidateIssuer(client.GetID()), + jwt.ValidateAudienceAny(issuer), + } - if v, ok = claims[consts.ClaimIssuer]; !ok { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectSignedAbsentClaim, request.GetClient().GetID(), consts.ClaimIssuer)) - } + if err = claims.Valid(optsValidClaims...); err != nil { + return errorsx.WithStack(fmtRequestObjectDecodeError(token, client, issuer, openid, err)) + } - clientID := request.GetClient().GetID() + /* + if !algNone { + issuer := f.Config.GetIDTokenIssuer(ctx) - if value, ok = v.(string); !ok { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectValueTypeNotString, request.GetClient().GetID(), consts.ClaimIssuer, v, clientID, v)) - } + if len(issuer) == 0 { + return errorsx.WithStack(ErrServerError.WithHintf("%s request could not be processed due to an authorization server configuration issue.", hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object that was signed but the issuer for this authorization server is not known.", request.GetClient().GetID())) + } - if value != clientID { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectValueMismatch, clientID, consts.ClaimIssuer, value, clientID)) - } + claimsOpts = append(claimsOpts, jwt.(issuer)) + if err = claims.Valid(jwt.ValidateIssuer(client.GetID()), jwt.ValidateAudienceAny(issuer)); err != nil { - if v, ok = claims[consts.ClaimAudience]; !ok { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectSignedAbsentClaim, request.GetClient().GetID(), consts.ClaimAudience)) - } + } - var valid bool + if v, ok = claims[consts.ClaimIssuer]; !ok { + return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectSignedAbsentClaim, request.GetClient().GetID(), consts.ClaimIssuer)) + } - switch t := v.(type) { - case string: - valid = t == issuer - case []string: - for _, value = range t { - if value == issuer { - valid = true + clientID := request.GetClient().GetID() - break - } + if value, ok = v.(string); !ok { + return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectValueTypeNotString, request.GetClient().GetID(), consts.ClaimIssuer, v, clientID, v)) } - case []any: - for _, x := range t { - if value, ok = x.(string); ok && value == issuer { - valid = true - break + if value != clientID { + return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectValueMismatch, clientID, consts.ClaimIssuer, value, clientID)) + } + + if v, ok = claims[consts.ClaimAudience]; !ok { + return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectSignedAbsentClaim, request.GetClient().GetID(), consts.ClaimAudience)) + } + + var valid bool + + switch t := v.(type) { + case string: + valid = t == issuer + case []string: + for _, value = range t { + if value == issuer { + valid = true + + break + } + } + case []any: + for _, x := range t { + if value, ok = x.(string); ok && value == issuer { + valid = true + + break + } } } - } - if !valid { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' included a request object with a 'aud' claim with the values '%s' which is required match the issuer '%s'.", request.GetClient().GetID(), value, issuer)) + if !valid { + return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' included a request object with a 'aud' claim with the values '%s' which is required match the issuer '%s'.", request.GetClient().GetID(), value, issuer)) + } } - } + + */ claimScope := RemoveEmpty(strings.Split(request.Form.Get(consts.FormParameterScope), " ")) for _, s := range scope { @@ -555,41 +591,57 @@ func (f *Fosite) newAuthorizeRequest(ctx context.Context, r *http.Request, isPAR return request, nil } -func wrapRequestObjectDecodeError(token *jwt.Token, client JARClient, openid bool, inner error) (outer *RFC6749Error) { +func fmtRequestObjectDecodeError(token *jwt.Token, client JARClient, issuer string, openid bool, inner error) (outer *RFC6749Error) { outer = ErrInvalidRequestObject.WithWrap(inner).WithHintf("%s request object could not be decoded or validated.", hintRequestObjectPrefix(openid)) if errJWTValidation := new(jwt.ValidationError); errors.As(inner, &errJWTValidation) { switch { - case errJWTValidation.Has(jwt.ValidationErrorKeyIDInvalid): - return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'kid' value '%s' but the request object was signed with the 'kid' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningKeyID(), token.KeyID) - case errJWTValidation.Has(jwt.ValidationErrorAlgorithmInvalid): - return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'alg' value '%s' but the request object was signed with the 'alg' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningAlg(), token.SignatureAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyIDInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'kid' value '%s' due to the client registration 'request_object_signing_key_id' value but the request object was signed with the 'kid' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningKeyID(), token.KeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderAlgorithmInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'alg' value '%s' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningAlg(), token.SignatureAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderTypeInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'typ' value '%s' but the request object was signed with the 'typ' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) case errJWTValidation.Has(jwt.ValidationErrorMalformed): - return outer.WithDebugf("%s client with id '%s' provided a request object that was malformed. The following error occurred trying to validate the request object: %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + return outer.WithDebugf("%s client with id '%s' provided a request object that was malformed. %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): - return outer.WithDebugf("%s client with id '%s' provided a request object that was not able to be verified. The following error occurred trying to validate the object: %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + return outer.WithDebugf("%s client with id '%s' provided a request object that was not able to be verified. %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) case errJWTValidation.Has(jwt.ValidationErrorSignatureInvalid): - return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid signature. The following error occurred trying to validate the request object signature: %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid signature. %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) case errJWTValidation.Has(jwt.ValidationErrorExpired): exp, ok := token.Claims.GetExpiresAt() if ok { return outer.WithDebugf("%s client with id '%s' provided a request object that was expired. The request object expired at %d.", hintRequestObjectPrefix(openid), client.GetID(), exp) } else { - return outer.WithDebugf("%s client with id '%s' provided a request object that was expired. Error occurred trying to validate the 'exp' claim': %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + return outer.WithDebugf("%s client with id '%s' provided a request object that was expired. The request object does not have an 'exp' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) } case errJWTValidation.Has(jwt.ValidationErrorIssuedAt): iat, ok := token.Claims.GetIssuedAt() if ok { return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object was issued at %d.", hintRequestObjectPrefix(openid), client.GetID(), iat) } else { - return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. Error occurred trying to validate the 'iat' claim: %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object does not have an 'iat' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) } case errJWTValidation.Has(jwt.ValidationErrorNotValidYet): nbf, ok := token.Claims.GetNotBefore() if ok { return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object is not valid before %d.", hintRequestObjectPrefix(openid), client.GetID(), nbf) } else { - return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. Error occurred trying to validate the 'nbf' claim: %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object does not have an 'nbf' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuer): + iss, ok := token.Claims.GetIssuer() + if ok { + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid issuer. The request object was expected to have an 'iss' claim which matches the value '%s' but the 'iss' claim had the value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetID(), iss) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid issuer. The request object does not have an 'iss' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorAudience): + aud, ok := token.Claims.GetAudience() + if ok { + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid audience. The request object was expected to have an 'aud' claim which matches the issuer value of '%s' but the 'aud' claim had the values '%s'.", hintRequestObjectPrefix(openid), client.GetID(), issuer, strings.Join(aud, "', '")) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid audience. The request object does not have an 'aud' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) } case errJWTValidation.Has(jwt.ValidationErrorClaimsInvalid): return outer.WithDebugf("%s client with id '%s' provided a request object that had one or more invalid claims. Error occurred trying to validate the request objects claims: %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) @@ -597,8 +649,8 @@ func wrapRequestObjectDecodeError(token *jwt.Token, client JARClient, openid boo return outer.WithDebugf("%s client with id '%s' provided a request object that could not be validated. Error occurred trying to validate the request object: %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) } } else if errJWKLookup := new(jwt.JWKLookupError); errors.As(inner, &errJWKLookup) { - return outer.WithDebugf("%s client with id '%s' provided a request object that could not be validated due to a key lookup error. %s", hintRequestObjectPrefix(openid), client.GetID(), errJWKLookup.Description) + return outer.WithDebugf("%s client with id '%s' provided a request object that could not be validated due to a key lookup error. %s.", hintRequestObjectPrefix(openid), client.GetID(), errJWKLookup.Description) } else { - return outer.WithDebugf("%s client with id '%s' provided a request object that could not be validated. %s", hintRequestObjectPrefix(openid), client.GetID(), ErrorToDebugRFC6749Error(inner).Error()) + return outer.WithDebugf("%s client with id '%s' provided a request object that could not be validated. %s.", hintRequestObjectPrefix(openid), client.GetID(), ErrorToDebugRFC6749Error(inner).Error()) } } diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index e6bdf7e2..5402e7de 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -46,7 +46,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) assertionRequestObjectInvalidAudience := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.not-example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") assertionRequestObjectInvalidIssuer := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "not-foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") assertionRequestObjectValidWithoutKID := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz"}, key, "") - assertionRequestObjectValidNone := mustGenerateNoneAssertion(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state"}) + assertionRequestObjectValidNone := mustGenerateNoneAssertion(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state", consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}}) mux := http.NewServeMux() @@ -187,7 +187,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' provided a request object that was malformed. The following error occurred trying to validate the request object: compact JWS format must have three parts.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' provided a request object that could not be validated. Provided value does not appear to be a JWE or JWS compact serialized JWT.", }, { name: "ShouldFailUnknownKID", @@ -195,15 +195,15 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test"}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' provided a request object that was not able to be verified. The following error occurred trying to validate the object: Error occurred looking up JSON Web Key: The JSON Web Token uses signing key with kid 'does-not-exists' which was not found..", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' provided a request object that was not able to be verified. Error occurred retrieving the JSON Web Key. The JSON Web Token uses signing key with kid 'does-not-exists' which was not found.", }, { name: "ShouldFailBadAlgRS256", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {mustGenerateHSAssertion(t, jwt.MapClaims{})}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test"}}, + client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test", ClientSecret: NewPlainTextClientSecret("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'alg' value 'RS256' but the request object was signed with the 'alg' value 'HS256'.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'alg' value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value 'HS256'.", }, { name: "ShouldFailMismatchedClientID", @@ -279,7 +279,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was registered with a 'request_object_signing_alg' value of 'RS256' but the request object had the 'alg' value 'none' in the header.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value 'none'.", }, { name: "ShouldFailRequestURIAlgNone", @@ -287,7 +287,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", RequestURIs: []string{root.JoinPath("request-object", "valid", "none.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was registered with a 'request_object_signing_alg' value of 'RS256' but the request object had the 'alg' value 'none' in the header.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value 'none'.", }, { name: "ShouldFailRequestRS256", @@ -295,7 +295,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'RS256' in the header.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'none' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value 'RS256'.", }, { name: "ShouldFailRequestURIRS256", @@ -303,7 +303,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, RequestURIs: []string{root.JoinPath("request-object", "valid", "standard.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "standard.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'RS256' in the header.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'none' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value 'RS256'.", }, { name: "ShouldPassRequestAlgNone", @@ -339,7 +339,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'aud' claim with the values 'https://auth.not-example.com' which is required match the issuer 'https://auth.example.com'.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' provided a request object that has an invalid audience. The request object was expected to have an 'aud' claim which matches the issuer value of 'https://auth.example.com' but the 'aud' claim had the values 'https://auth.not-example.com'.", }, { name: "ShouldFailRequestURIBadAudience", @@ -355,7 +355,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'iss' claim with a value of 'not-foo' which is required to match the value 'foo' in the parameter with the same name from the OAuth 2.0 request syntax.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' provided a request object that has an invalid issuer. The request object was expected to have an 'iss' claim which matches the value 'foo' but the 'iss' claim had the value 'not-foo'.", }, { name: "ShouldFailRequestURIBadIssuer", diff --git a/client_authentication_secret_plaintext.go b/client_authentication_secret_plaintext.go new file mode 100644 index 00000000..4d7ffc7e --- /dev/null +++ b/client_authentication_secret_plaintext.go @@ -0,0 +1,41 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oauth2 + +import ( + "context" + "crypto/subtle" + "fmt" + + "authelia.com/provider/oauth2/x/errorsx" +) + +// NewPlainTextClientSecret returns a new PlainTextClientSecret given a value. +func NewPlainTextClientSecret(value string) *PlainTextClientSecret { + return &PlainTextClientSecret{value: []byte(value)} +} + +type PlainTextClientSecret struct { + value []byte +} + +func (s *PlainTextClientSecret) IsPlainText() (is bool) { + return true +} + +func (s *PlainTextClientSecret) GetPlainTextValue() (secret []byte, err error) { + return s.value, nil +} + +func (s *PlainTextClientSecret) Compare(ctx context.Context, secret []byte) (err error) { + if subtle.ConstantTimeCompare(s.value, secret) == 0 { + return errorsx.WithStack(fmt.Errorf("secrets don't match")) + } + + return nil +} + +func (s *PlainTextClientSecret) Valid() (valid bool) { + return s != nil && len(s.value) != 0 +} diff --git a/token/jwt/claims_map.go b/token/jwt/claims_map.go index 9c8aec2a..8b57c01e 100644 --- a/token/jwt/claims_map.go +++ b/token/jwt/claims_map.go @@ -169,12 +169,11 @@ func (m MapClaims) VerifyNotBefore(cmp int64, required bool) (ok bool) { return verifyInt64Past(nbf, cmp, required) } -// Valid validates time based claims "exp, iat, nbf". -// There is no accounting for clock skew. -// As well, if any of the above claims are not in the token, it will still -// be considered a valid claim. -func (m MapClaims) Valid(opts ...ValidationOpt) error { - vopts := &optsValidation{} +// Valid validates the given claims. By default it only validates time based claims "exp, iat, nbf"; there is no +// accounting for clock skew, and if any of the above claims are not in the token, the claims will still be considered +// valid. However all of these options can be tuned by the opts. +func (m MapClaims) Valid(opts ...ClaimValidationOption) (err error) { + vopts := &ClaimValidationOptions{} for _, opt := range opts { opt(vopts) @@ -281,9 +280,9 @@ func (m MapClaims) toInt64(claim string) (val int64, ok bool) { return 0, false } -type ValidationOpt func(opts *optsValidation) +type ClaimValidationOption func(opts *ClaimValidationOptions) -type optsValidation struct { +type ClaimValidationOptions struct { timef func() time.Time iss string aud []string @@ -294,50 +293,50 @@ type optsValidation struct { nbfRequired bool } -func ValidateTimeFunc(timef func() time.Time) ValidationOpt { - return func(opts *optsValidation) { +func ValidateTimeFunc(timef func() time.Time) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { opts.timef = timef } } -func ValidateIssuer(iss string) ValidationOpt { - return func(opts *optsValidation) { +func ValidateIssuer(iss string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { opts.iss = iss } } -func ValidateAudienceAny(aud ...string) ValidationOpt { - return func(opts *optsValidation) { +func ValidateAudienceAny(aud ...string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { opts.aud = aud } } -func ValidateAudienceAll(aud ...string) ValidationOpt { - return func(opts *optsValidation) { +func ValidateAudienceAll(aud ...string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { opts.audAll = aud } } -func ValidateSubject(sub string) ValidationOpt { - return func(opts *optsValidation) { +func ValidateSubject(sub string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { opts.sub = sub } } -func ValidateRequireExpiresAt() ValidationOpt { - return func(opts *optsValidation) { +func ValidateRequireExpiresAt() ClaimValidationOption { + return func(opts *ClaimValidationOptions) { opts.expRequired = true } } -func ValidateRequireIssuedAt() ValidationOpt { - return func(opts *optsValidation) { +func ValidateRequireIssuedAt() ClaimValidationOption { + return func(opts *ClaimValidationOptions) { opts.iatRequired = true } } -func ValidateRequireNotBefore() ValidationOpt { - return func(opts *optsValidation) { +func ValidateRequireNotBefore() ClaimValidationOption { + return func(opts *ClaimValidationOptions) { opts.nbfRequired = true } } diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index adadb29d..2bf35db4 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -11,11 +11,22 @@ import ( "authelia.com/provider/oauth2/x/errorsx" ) -// Strategy represents the strategy for encoding and decoding JWT's. +// Strategy represents the strategy for encoding and decoding JWT's. It's important to note that this is an interface +// specifically so it can be mocked and the opts values have very important semantics which are difficult to document. type Strategy interface { + // Encode a JWT as either a JWS or JWE nested JWS. Encode(ctx context.Context, opts ...StrategyOpt) (tokenString string, signature string, err error) - Decode(ctx context.Context, tokenString string, opts ...StrategyOpt) (token *Token, err error) + + // Decrypt a JWT or if the provided JWT is a JWS just return it. Decrypt(ctx context.Context, tokenStringEnc string, opts ...StrategyOpt) (tokenString, signature string, jwe *jose.JSONWebEncryption, err error) + + // Decode a JWT. This performs decryption as well as basic signature validation. Optionally the signature validation + // can be skipped and validated later using Validate. + Decode(ctx context.Context, tokenString string, opts ...StrategyOpt) (token *Token, err error) + + // Validate allows performing the signature validation step after using the Decode function without a client while + // also using WithAllowUnverified. + Validate(ctx context.Context, token *Token, opts ...StrategyOpt) (err error) } type StrategyConfig interface { @@ -87,81 +98,82 @@ func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (toke return encodeNestedCompactEncrypted(ctx, o.claims, o.headers, o.headersJWE, keySig, keyEnc, jose.ContentEncryption(enc)) } -func (j *DefaultStrategy) Validate(ctx context.Context, token *Token, opts ...StrategyOpt) (err error) { - if token == nil { - return errorsx.WithStack(fmt.Errorf("token is nil")) - } - - if token.valid { - return nil - } - - if token.parsedToken == nil { - return errorsx.WithStack(fmt.Errorf("token is in an inconsistent state")) +func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, opts ...StrategyOpt) (tokenString, signature string, jwe *jose.JSONWebEncryption, err error) { + if !IsEncryptedJWT(tokenStringEnc) { + if IsSignedJWT(tokenStringEnc) { + return tokenStringEnc, "", nil, nil + } else { + return tokenStringEnc, "", nil, fmt.Errorf("Provided value does not appear to be a JWE or JWS compact serialized JWT") + } } o := &StrategyOpts{ - sigAlgorithm: SignatureAlgorithms, + sigAlgorithm: SignatureAlgorithmsNone, keyAlgorithm: EncryptionKeyAlgorithms, contentEncryption: ContentEncryptionAlgorithms, - jwsKeyFunc: nil, - jweKeyFunc: nil, } for _, opt := range opts { if err = opt(o); err != nil { - return errorsx.WithStack(err) + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } } - if err = j.validate(ctx, token.parsedToken, &MapClaims{}, o); err != nil { - return err - } - - token.valid = true + var ( + key *jose.JSONWebKey + ) - return nil -} + if jwe, err = jose.ParseEncryptedCompact(tokenStringEnc, o.keyAlgorithm, o.contentEncryption); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } -func (j *DefaultStrategy) validate(ctx context.Context, t *jwt.JSONWebToken, dest any, o *StrategyOpts) (err error) { var ( - key *jose.JSONWebKey - kid, alg string + kid, alg, cty string ) - if kid, alg, err = headerValidateJWS(t.Headers); err != nil { - return errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + if kid, alg, _, cty, err = headerValidateJWE(jwe.Header); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } - claims := MapClaims{} - - if o.jwsKeyFunc != nil { - if key, err = o.jwsKeyFunc(ctx, t, claims); err != nil { - return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + if o.jweKeyFunc != nil { + if key, err = o.jweKeyFunc(ctx, jwe, kid, alg); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if IsEncryptedJWTClientSecretAlg(alg) { + if o.client == nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } - } else if o.client != nil && o.client.IsClientSigned() { - if IsSignedJWTClientSecretAlg(alg) { - if kid != "" { - return errorsx.WithStack(&ValidationError{Errors: ValidationErrorKeyIDInvalid, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not support tokens with a kid but the token has kid '%s'", alg, kid)}) - } - if key, err = NewJWKFromClientSecret(ctx, o.client, "", alg, consts.JSONWebTokenUseSignature); err != nil { - return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - } else { - if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseSignature, true); err != nil { - return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } + if key, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } - } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseSignature); err != nil { - return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } - if err = t.Claims(key.Public(), &dest); err != nil { - return errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) + var tokenRaw []byte + + if tokenRaw, err = jwe.Decrypt(key); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } - return nil + tokenString = string(tokenRaw) + + var t *jwt.JSONWebToken + + if t, err = jwt.ParseSigned(tokenString, SignatureAlgorithmsNone); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + if err = headerValidateJWSNested(t.Headers, cty); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + if signature, err = getJWTSignature(tokenString); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + return tokenString, signature, jwe, nil } func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts ...StrategyOpt) (token *Token, err error) { @@ -180,54 +192,65 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . } var ( - key *jose.JSONWebKey + // key *jose.JSONWebKey t *jwt.JSONWebToken jwe *jose.JSONWebEncryption ) - if IsEncryptedJWT(tokenString) { - if jwe, err = jose.ParseEncryptedCompact(tokenString, o.keyAlgorithm, o.contentEncryption); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } + tokenString, _, jwe, err = j.Decrypt(ctx, tokenString, opts...) + if err != nil { + return token, err + } - var ( - kid, alg, cty string - ) + /* + if IsEncryptedJWT(tokenString) { + if jwe, err = jose.ParseEncryptedCompact(tokenString, o.keyAlgorithm, o.contentEncryption); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } - if kid, alg, _, cty, err = headerValidateJWE(jwe.Header); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } + var ( + kid, alg, cty string + ) - if o.jweKeyFunc != nil { - if key, err = o.jweKeyFunc(ctx, jwe, kid, alg); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - } else if IsEncryptedJWTClientSecretAlg(alg) { - if o.client == nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + if kid, alg, _, cty, err = headerValidateJWE(jwe.Header); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } - if key, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + if o.jweKeyFunc != nil { + if key, err = o.jweKeyFunc(ctx, jwe, kid, alg); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if IsEncryptedJWTClientSecretAlg(alg) { + if o.client == nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + + if key, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } - } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - var rawJWT []byte + var rawJWT []byte - if rawJWT, err = jwe.Decrypt(key); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } + if rawJWT, err = jwe.Decrypt(key); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } - if t, err = jwt.ParseSigned(string(rawJWT), o.sigAlgorithm); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } + if t, err = jwt.ParseSigned(string(rawJWT), o.sigAlgorithm); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } - if err = headerValidateJWSNested(t.Headers, cty); err != nil { + if err = headerValidateJWSNested(t.Headers, cty); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + } else if t, err = jwt.ParseSigned(tokenString, o.sigAlgorithm); err != nil { return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } - } else if t, err = jwt.ParseSigned(tokenString, o.sigAlgorithm); err != nil { + */ + + if t, err = jwt.ParseSigned(tokenString, o.sigAlgorithm); err != nil { return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } @@ -264,80 +287,79 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . return token, nil } -func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, opts ...StrategyOpt) (tokenString, signature string, jwe *jose.JSONWebEncryption, err error) { - if !IsEncryptedJWT(tokenStringEnc) { - if IsSignedJWT(tokenStringEnc) { - return tokenStringEnc, "", nil, nil - } else { - return tokenStringEnc, "", nil, fmt.Errorf("token does not appear to be a jwe or jws compact serializd jwt") - } +func (j *DefaultStrategy) Validate(ctx context.Context, token *Token, opts ...StrategyOpt) (err error) { + if token == nil { + return errorsx.WithStack(fmt.Errorf("token is nil")) + } + + if token.valid { + return nil + } + + if token.parsedToken == nil { + return errorsx.WithStack(fmt.Errorf("token is in an inconsistent state")) } o := &StrategyOpts{ - sigAlgorithm: SignatureAlgorithmsNone, + sigAlgorithm: SignatureAlgorithms, keyAlgorithm: EncryptionKeyAlgorithms, contentEncryption: ContentEncryptionAlgorithms, + jwsKeyFunc: nil, + jweKeyFunc: nil, } for _, opt := range opts { if err = opt(o); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + return errorsx.WithStack(err) } } - var ( - key *jose.JSONWebKey - ) - - if jwe, err = jose.ParseEncryptedCompact(tokenStringEnc, o.keyAlgorithm, o.contentEncryption); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + if err = j.validate(ctx, token.parsedToken, &MapClaims{}, o); err != nil { + return err } + token.valid = true + + return nil +} + +func (j *DefaultStrategy) validate(ctx context.Context, t *jwt.JSONWebToken, dest any, o *StrategyOpts) (err error) { var ( - kid, alg, cty string + key *jose.JSONWebKey + kid, alg string ) - if kid, alg, _, cty, err = headerValidateJWE(jwe.Header); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + if kid, alg, err = headerValidateJWS(t.Headers); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } - if o.jweKeyFunc != nil { - if key, err = o.jweKeyFunc(ctx, jwe, kid, alg); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - } else if IsEncryptedJWTClientSecretAlg(alg) { - if o.client == nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } + claims := MapClaims{} - if key, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + if o.jwsKeyFunc != nil { + if key, err = o.jwsKeyFunc(ctx, t, claims); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } - } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - - var tokenRaw []byte - - if tokenRaw, err = jwe.Decrypt(key); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } - - tokenString = string(tokenRaw) - - var t *jwt.JSONWebToken - - if t, err = jwt.ParseSigned(tokenString, SignatureAlgorithmsNone); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } + } else if o.client != nil && o.client.IsClientSigned() { + if IsSignedJWTClientSecretAlg(alg) { + if kid != "" { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorHeaderKeyIDInvalid, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not support tokens with a kid but the token has kid '%s'", alg, kid)}) + } - if err = headerValidateJWSNested(t.Headers, cty); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + if key, err = NewJWKFromClientSecret(ctx, o.client, "", alg, consts.JSONWebTokenUseSignature); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else { + if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseSignature, true); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseSignature); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } - if signature, err = getJWTSignature(tokenString); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + if err = t.Claims(getPublicJWK(key), &dest); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) } - return tokenString, signature, jwe, nil + return nil } diff --git a/token/jwt/token.go b/token/jwt/token.go index 2f926c1e..cc57f69a 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -150,9 +150,9 @@ type Token struct { valid bool } -// Valid informs if the token was verified against a given verification key +// IsSignatureValid informs if the token was verified against a given verification key // and claims are valid -func (t *Token) Valid() bool { +func (t *Token) IsSignatureValid() bool { return t.valid } @@ -343,8 +343,44 @@ func (t *Token) CompactSignedString(k any) (tokenString string, err error) { return tokenString, nil } -func (t *Token) Validate() { +// Valid validates the token headers given various input options. This does not validate any claims. +func (t *Token) Valid(opts ...TokenValidationOption) (err error) { + vopts := &TokenValidationOptions{ + types: []string{consts.JSONWebTokenTypeJWT}, + } + + for _, opt := range opts { + opt(vopts) + } + vErr := new(ValidationError) + + if len(vopts.types) != 0 { + if !validateTokenType(vopts.types, t.Header) { + vErr.Inner = errors.New("token has an invalid typ") + vErr.Errors |= ValidationErrorHeaderTypeInvalid + } + } + + if len(vopts.alg) != 0 { + if vopts.alg != string(t.SignatureAlgorithm) { + vErr.Inner = errors.New("token has an invalid alg") + vErr.Errors |= ValidationErrorHeaderAlgorithmInvalid + } + } + + if len(vopts.kid) != 0 { + if vopts.kid != t.KeyID { + vErr.Inner = errors.New("token has an invalid kid") + vErr.Errors |= ValidationErrorHeaderKeyIDInvalid + } + } + + if vErr.valid() { + return nil + } + + return vErr } // IsJWTProfileAccessToken returns true if the token is a JWT Profile Access Token. @@ -377,6 +413,32 @@ func (t *Token) IsJWTProfileAccessToken() (ok bool) { return ok && (typ == consts.JSONWebTokenTypeAccessToken || typ == consts.JSONWebTokenTypeAccessTokenAlternative) } +type TokenValidationOption func(opts *TokenValidationOptions) + +type TokenValidationOptions struct { + types []string + alg string + kid string +} + +func ValidateTypes(types ...string) TokenValidationOption { + return func(validator *TokenValidationOptions) { + validator.types = types + } +} + +func ValidateAlgorithm(alg string) TokenValidationOption { + return func(validator *TokenValidationOptions) { + validator.alg = alg + } +} + +func ValidateKeyID(kid string) TokenValidationOption { + return func(validator *TokenValidationOptions) { + validator.kid = kid + } +} + func unsignedToken(token *Token) (tokenString string, err error) { token.Header[consts.JSONWebTokenHeaderAlgorithm] = consts.JSONWebTokenAlgNone @@ -435,3 +497,27 @@ func pointer(v any) any { } return v } + +func validateTokenType(typValues []string, header map[string]any) bool { + var ( + typ string + raw any + ok bool + ) + + if raw, ok = header[consts.JSONWebTokenHeaderType]; !ok { + return false + } + + if typ, ok = raw.(string); !ok { + return false + } + + for _, t := range typValues { + if t == typ { + return true + } + } + + return false +} diff --git a/token/jwt/token_test.go b/token/jwt/token_test.go index 3a0d0d4e..00cb9d89 100644 --- a/token/jwt/token_test.go +++ b/token/jwt/token_test.go @@ -428,7 +428,7 @@ func TestParser_Parse(t *testing.T) { t.Errorf("[%v] Invalid token passed validation", data.name) } - if (err == nil && !token.Valid()) || (err != nil && token.Valid()) { + if (err == nil && !token.IsSignatureValid()) || (err != nil && token.IsSignatureValid()) { t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name) } diff --git a/token/jwt/util.go b/token/jwt/util.go index 81cca2fa..c0f69c66 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -14,7 +14,7 @@ import ( ) var ( - reSignedJWT = regexp.MustCompile(`^[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+$`) + reSignedJWT = regexp.MustCompile(`^[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.([-_A-Za-z0-9]+)?$`) reEncryptedJWT = regexp.MustCompile(`^[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+$`) ) @@ -180,8 +180,13 @@ type PrivateKey interface { Equal(x crypto.PrivateKey) bool } +const ( + JWKLookupErrorClientNoJWKS uint32 = 1 << iota +) + type JWKLookupError struct { Description string + Errors uint32 // bitfield. see JWKLookupError... constants } func (e *JWKLookupError) GetDescription() string { @@ -189,7 +194,7 @@ func (e *JWKLookupError) GetDescription() string { } func (e *JWKLookupError) Error() string { - return fmt.Sprintf("Error occurred retriving the JSON Web Key. %s", e.Description) + return fmt.Sprintf("Error occurred retrieving the JSON Web Key. %s", e.Description) } // FindClientPublicJWK given a BaseClient, JWKSFetcherStrategy, and search parameters will return a *jose.JSONWebKey on @@ -237,7 +242,7 @@ func SearchJWKS(jwks *jose.JSONWebKeySet, kid, alg, use string, strict bool) (ke } if len(keys) == 0 { - return nil, &JWKLookupError{Description: fmt.Sprintf("The JSON Web Token uses signing key with kid '%s' which was not found.", kid)} + return nil, &JWKLookupError{Description: fmt.Sprintf("The JSON Web Token uses signing key with kid '%s' which was not found", kid)} } var matched []jose.JSONWebKey @@ -258,10 +263,10 @@ func SearchJWKS(jwks *jose.JSONWebKeySet, kid, alg, use string, strict bool) (ke case 1: return &matched[0], nil case 0: - return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to find JSON web key with kid '%s', use '%s', and alg '%s' in JSON Web Key Set.", kid, use, alg)} + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to find JSON web key with kid '%s', use '%s', and alg '%s' in JSON Web Key Set", kid, use, alg)} default: if strict { - return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to find JSON web key with kid '%s', use '%s', and alg '%s' in JSON Web Key Set.", kid, use, alg)} + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to find JSON web key with kid '%s', use '%s', and alg '%s' in JSON Web Key Set", kid, use, alg)} } return &matched[0], nil @@ -276,15 +281,15 @@ func NewJWKFromClientSecret(ctx context.Context, client BaseClient, kid, alg, us ) if secret, ok, err = client.GetClientSecretPlainText(); err != nil { - return nil, &JWKLookupError{Description: fmt.Sprintf("The client returned an error while trying to retrieve the plaintext client secret: %s.", err.Error())} + return nil, &JWKLookupError{Description: fmt.Sprintf("The client returned an error while trying to retrieve the plaintext client secret. %s", err.Error())} } if !ok { - return nil, &JWKLookupError{Description: "The client is not configured with a client secret."} + return nil, &JWKLookupError{Description: "The client is not configured with a client secret"} } if len(secret) == 0 { - return nil, &JWKLookupError{Description: "The client is not configured with a client secret that can be used for symmetric algorithms."} + return nil, &JWKLookupError{Description: "The client is not configured with a client secret that can be used for symmetric algorithms"} } return &jose.JSONWebKey{ @@ -332,3 +337,24 @@ func assign(a, b map[string]any) map[string]any { } return a } + +func getPublicJWK(jwk *jose.JSONWebKey) jose.JSONWebKey { + if jwk == nil { + return jose.JSONWebKey{} + } + + if _, ok := jwk.Key.([]byte); ok && IsSignedJWTClientSecretAlg(jwk.Algorithm) { + return jose.JSONWebKey{ + KeyID: jwk.KeyID, + Key: jwk.Key, + Algorithm: jwk.Algorithm, + Use: jwk.Use, + Certificates: jwk.Certificates, + CertificatesURL: jwk.CertificatesURL, + CertificateThumbprintSHA1: jwk.CertificateThumbprintSHA1, + CertificateThumbprintSHA256: jwk.CertificateThumbprintSHA256, + } + } + + return jwk.Public() +} diff --git a/token/jwt/validation_error.go b/token/jwt/validation_error.go index 73473e54..548c581e 100644 --- a/token/jwt/validation_error.go +++ b/token/jwt/validation_error.go @@ -5,29 +5,24 @@ package jwt // Validation provides a backwards compatible error definition // from `jwt-go` to `go-jose`. -// The sourcecode was taken from https://github.com/dgrijalva/jwt-go/blob/master/errors.go -// -// > The errors that might occur when parsing and validating a token const ( - ValidationErrorMalformed uint32 = 1 << iota // Token is malformed - ValidationErrorUnverifiable // Token could not be verified because of signing problems - ValidationErrorSignatureInvalid // Signature validation failed - - // Standard Claim validation errors - ValidationErrorId // Claim JTI validation failed - ValidationErrorAudience // Claim AUD validation failed - ValidationErrorExpired // Claim EXP validation failed - ValidationErrorIssuedAt // Claim IAT validation failed - ValidationErrorNotValidYet // Claim NBF validation failed - ValidationErrorIssuer // Claim ISS validation failed - ValidationErrorSubject // Claim SUB validation failed - ValidationErrorTypInvalid // Header TYP invalid error - ValidationErrorKeyIDInvalid // Header KID invalid error - ValidationErrorAlgorithmInvalid // Header ALG invalid error - ValidationErrorClaimsInvalid // Generic claims validation error + ValidationErrorMalformed uint32 = 1 << iota // Token is malformed + ValidationErrorUnverifiable // Token could not be verified because of signing problems + ValidationErrorSignatureInvalid // Signature validation failed. + ValidationErrorHeaderKeyIDInvalid // Header KID invalid error. + ValidationErrorHeaderAlgorithmInvalid // Header ALG invalid error. + ValidationErrorHeaderTypeInvalid // Header TYP invalid error + ValidationErrorId // Claim JTI validation failed + ValidationErrorAudience // Claim AUD validation failed + ValidationErrorExpired // Claim EXP validation failed + ValidationErrorIssuedAt // Claim IAT validation failed + ValidationErrorNotValidYet // Claim NBF validation failed + ValidationErrorIssuer // Claim ISS validation failed + ValidationErrorSubject // Claim SUB validation failed + ValidationErrorClaimsInvalid // Generic claims validation error ) -// The error from Parse if token is not valid +// The ValidationError is an error implementation from Parse if token is not valid. type ValidationError struct { Inner error // stores the error returned by external dependencies, i.e.: KeyFunc Errors uint32 // bitfield. see ValidationError... constants diff --git a/token/jwt/validator.go b/token/jwt/validator.go index 4e7f61fa..0efde232 100644 --- a/token/jwt/validator.go +++ b/token/jwt/validator.go @@ -1,11 +1,6 @@ package jwt -import ( - "errors" - - "authelia.com/provider/oauth2/internal/consts" -) - +/* func NewValidator(opts ...ValidatorOpt) (validator *Validator) { validator = &Validator{ types: []string{consts.JSONWebTokenTypeJWT}, @@ -23,24 +18,6 @@ func NewValidator(opts ...ValidatorOpt) (validator *Validator) { type ValidatorOpt func(*Validator) -func ValidateTypes(types ...string) ValidatorOpt { - return func(validator *Validator) { - validator.types = types - } -} - -func ValidateAlgorithm(alg string) ValidatorOpt { - return func(validator *Validator) { - validator.alg = alg - } -} - -func ValidateKeyID(kid string) ValidatorOpt { - return func(validator *Validator) { - validator.kid = kid - } -} - func ValidateIssuer(iss string) ValidatorOpt { return func(validator *Validator) { validator.iss = iss @@ -95,12 +72,15 @@ func ValidateIssuedAt(iat int64) ValidatorOpt { } } -func ValidateRequireIssuedAt() ValidatorOpt { +func ValidateRequireIssuedAt() TokenValidationOption { return func(validator *Validator) { validator.requireIAT = true } } +*/ + +/* type Validator struct { types []string alg string @@ -124,21 +104,21 @@ func (v Validator) Validate(token *Token) (err error) { if len(v.types) != 0 { if !validateTokenType(v.types, token.Header) { vErr.Inner = errors.New("token has an invalid typ") - vErr.Errors |= ValidationErrorTypInvalid + vErr.Errors |= ValidationErrorHeaderTypeInvalid } } if len(v.alg) != 0 { if v.alg != string(token.SignatureAlgorithm) { vErr.Inner = errors.New("token has an invalid alg") - vErr.Errors |= ValidationErrorAlgorithmInvalid + vErr.Errors |= ValidationErrorHeaderAlgorithmInvalid } } if len(v.kid) != 0 { if v.kid != token.KeyID { vErr.Inner = errors.New("token has an invalid kid") - vErr.Errors |= ValidationErrorKeyIDInvalid + vErr.Errors |= ValidationErrorHeaderKeyIDInvalid } } @@ -216,26 +196,4 @@ func (v Validator) Validate(token *Token) (err error) { return vErr } -func validateTokenType(typValues []string, header map[string]any) bool { - var ( - typ string - raw any - ok bool - ) - - if raw, ok = header[consts.JSONWebTokenHeaderType]; !ok { - return false - } - - if typ, ok = raw.(string); !ok { - return false - } - - for _, t := range typValues { - if t == typ { - return true - } - } - - return false -} +*/ From 17f8056ab29755cec104e5a7bdf49a78a6b316a3 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sun, 22 Sep 2024 13:55:28 +1000 Subject: [PATCH 08/33] temp --- handler/oauth2/introspector_jwt.go | 4 +- handler/oauth2/introspector_jwt_test.go | 128 ++++++++++++-------- handler/oauth2/strategy_jwt_profile.go | 109 ++++++++++++----- handler/oauth2/strategy_jwt_profile_test.go | 41 ++++++- token/jwt/client.go | 31 ++--- token/jwt/client_test.go | 9 +- token/jwt/jwt_strategy.go | 2 +- 7 files changed, 218 insertions(+), 106 deletions(-) diff --git a/handler/oauth2/introspector_jwt.go b/handler/oauth2/introspector_jwt.go index 0db62d1c..2b55be38 100644 --- a/handler/oauth2/introspector_jwt.go +++ b/handler/oauth2/introspector_jwt.go @@ -27,8 +27,8 @@ func (v *StatelessJWTValidator) IntrospectToken(ctx context.Context, tokenString return "", err } - if !token.IsJWTProfileAccessToken() { - return "", errorsx.WithStack(oauth2.ErrRequestUnauthorized.WithDebug("The provided token is not a valid RFC9068 JWT Profile Access Token as it is missing the header 'typ' value of 'at+jwt' ")) + if err = token.Valid(jwt.ValidateTypes(consts.JSONWebTokenTypeAccessToken, consts.JSONWebTokenTypeAccessTokenAlternative)); err != nil { + return "", errorsx.WithStack(oauth2.ErrRequestUnauthorized.WithDebug("The provided token is not a valid RFC9068 JWT Profile Access Token as it is missing the header 'typ' value of 'at+jwt'.")) } r := AccessTokenJWTToRequest(token) diff --git a/handler/oauth2/introspector_jwt_test.go b/handler/oauth2/introspector_jwt_test.go index 9398e1ef..c3563a73 100644 --- a/handler/oauth2/introspector_jwt_test.go +++ b/handler/oauth2/introspector_jwt_test.go @@ -6,9 +6,9 @@ package oauth2 import ( "context" "encoding/base64" - "fmt" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -33,87 +33,115 @@ func TestIntrospectJWT(t *testing.T) { Config: config, } - var v = &StatelessJWTValidator{ + validator := &StatelessJWTValidator{ Strategy: strategy, Config: &oauth2.Config{ ScopeStrategy: oauth2.HierarchicScopeStrategy, }, } - for k, c := range []struct { - description string - token func() string - expectErr error - scopes []string + testCases := []struct { + name string + token func(t *testing.T) string + err error + expected string + scopes []string }{ { - description: "should fail because jwt is expired", - token: func() string { - jwt := jwtExpiredCase(oauth2.AccessToken) - token, _, err := strategy.GenerateAccessToken(context.TODO(), jwt) - assert.NoError(t, err) - return token + name: "ShouldFailTokenExpired", + token: func(t *testing.T) string { + token := jwtExpiredCase(oauth2.AccessToken, time.Unix(1726972738, 0)) + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + require.NoError(t, err) + + return tokenString }, - expectErr: oauth2.ErrTokenExpired, + err: oauth2.ErrTokenExpired, + expected: "Token expired. The token expired. Token expired at 1726969138.", }, { - description: "should pass because scope was granted", - token: func() string { - jwt := jwtValidCase(oauth2.AccessToken) - jwt.GrantedScope = []string{"foo", "bar"} - token, _, err := strategy.GenerateAccessToken(context.TODO(), jwt) - assert.NoError(t, err) - return token + name: "ShouldPassScopeGranted", + token: func(t *testing.T) string { + token := jwtValidCase(oauth2.AccessToken) + token.GrantedScope = []string{"foo", "bar"} + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + + require.NoError(t, err) + + return tokenString }, scopes: []string{"foo"}, }, { - description: "should fail because scope was not granted", - token: func() string { - jwt := jwtValidCase(oauth2.AccessToken) - token, _, err := strategy.GenerateAccessToken(context.TODO(), jwt) - assert.NoError(t, err) - return token + name: "ShouldFailWrongTyp", + token: func(t *testing.T) string { + token := jwtInvalidTypCase(oauth2.AccessToken) + token.GrantedScope = []string{"foo", "bar"} + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + + require.NoError(t, err) + + return tokenString }, - scopes: []string{"foo"}, - expectErr: oauth2.ErrInvalidScope, + scopes: []string{"foo"}, + err: oauth2.ErrRequestUnauthorized, + expected: "The request could not be authorized. Check that you provided valid credentials in the right format. The provided token is not a valid RFC9068 JWT Profile Access Token as it is missing the header 'typ' value of 'at+jwt'.", }, { - description: "should fail because signature is invalid", - token: func() string { - jwt := jwtValidCase(oauth2.AccessToken) - token, _, err := strategy.GenerateAccessToken(context.TODO(), jwt) - assert.NoError(t, err) - parts := strings.Split(token, ".") - require.Len(t, parts, 3, "%s - %v", token, parts) + name: "ShouldFailScopeNotGranted", + token: func(t *testing.T) string { + token := jwtValidCase(oauth2.AccessToken) + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + require.NoError(t, err) + + return tokenString + }, + scopes: []string{"foo"}, + err: oauth2.ErrInvalidScope, + expected: "The requested scope is invalid, unknown, or malformed. The request scope 'foo' has not been granted or is not allowed to be requested.", + }, + { + name: "ShouldFailInvalidSignature", + token: func(t *testing.T) string { + token := jwtValidCase(oauth2.AccessToken) + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + require.NoError(t, err) + parts := strings.Split(tokenString, ".") + require.Len(t, parts, 3, "%s - %v", tokenString, parts) dec, err := base64.RawURLEncoding.DecodeString(parts[1]) assert.NoError(t, err) s := strings.ReplaceAll(string(dec), "peter", "piper") parts[1] = base64.RawURLEncoding.EncodeToString([]byte(s)) + return strings.Join(parts, ".") }, - expectErr: oauth2.ErrTokenSignatureMismatch, + err: oauth2.ErrTokenSignatureMismatch, + expected: "Token signature mismatch. Check that you provided a valid token in the right format. Token has an invalid signature.", }, { - description: "should pass", - token: func() string { - jwt := jwtValidCase(oauth2.AccessToken) - token, _, err := strategy.GenerateAccessToken(context.TODO(), jwt) - assert.NoError(t, err) - return token + name: "ShouldPass", + token: func(t *testing.T) string { + token := jwtValidCase(oauth2.AccessToken) + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + require.NoError(t, err) + + return tokenString }, }, - } { - t.Run(fmt.Sprintf("case=%d:%v", k, c.description), func(t *testing.T) { - if c.scopes == nil { - c.scopes = []string{} + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.scopes == nil { + tc.scopes = []string{} } areq := oauth2.NewAccessRequest(nil) - _, err := v.IntrospectToken(context.TODO(), c.token(), oauth2.AccessToken, areq, c.scopes) + _, err := validator.IntrospectToken(context.TODO(), tc.token(t), oauth2.AccessToken, areq, tc.scopes) - if c.expectErr != nil { - require.EqualError(t, err, c.expectErr.Error()) + if tc.err != nil { + assert.EqualError(t, err, tc.err.Error()) + assert.EqualError(t, oauth2.ErrorToDebugRFC6749Error(err), tc.expected) } else { require.NoError(t, err) assert.Equal(t, "peter", areq.Session.GetSubject()) diff --git a/handler/oauth2/strategy_jwt_profile.go b/handler/oauth2/strategy_jwt_profile.go index 4d688803..95608478 100644 --- a/handler/oauth2/strategy_jwt_profile.go +++ b/handler/oauth2/strategy_jwt_profile.go @@ -5,6 +5,7 @@ package oauth2 import ( "context" + "fmt" "strings" "time" @@ -13,7 +14,6 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/token/jwt" - "authelia.com/provider/oauth2/x/errorsx" ) // JWTProfileCoreStrategy is a JWT RS256 strategy. @@ -174,39 +174,86 @@ func (s *JWTProfileCoreStrategy) GenerateJWT(ctx context.Context, tokenType oaut return s.Strategy.Encode(ctx, jwt.WithClaims(claims.ToMapClaims()), jwt.WithHeaders(header), jwt.WithJWTProfileAccessTokenClient(client)) } -func validateJWT(ctx context.Context, jwtStrategy jwt.Strategy, client jwt.Client, token string) (t *jwt.Token, err error) { - t, err = jwtStrategy.Decode(ctx, token, jwt.WithClient(client)) - if err == nil { - err = t.Claims.Valid() - return +func validateJWT(ctx context.Context, strategy jwt.Strategy, client jwt.Client, tokenString string) (token *jwt.Token, err error) { + if token, err = strategy.Decode(ctx, tokenString, jwt.WithClient(client)); err != nil { + return token, fmtValidateJWTError(token, client, err) + } + + if err = token.Claims.Valid(); err != nil { + return token, fmtValidateJWTError(token, client, err) } - var e *jwt.ValidationError - if err != nil && errors.As(err, &e) { - err = errorsx.WithStack(toRFCErr(e).WithWrap(err).WithDebugError(err)) + return token, nil +} + +func fmtValidateJWTError(token *jwt.Token, client jwt.Client, inner error) (err error) { + var ( + clientText string + skid, salg string + ) + + if client != nil { + clientText = fmt.Sprintf("provided by client with id '%s' ", client.GetID()) + skid, salg = client.GetSigningKeyID(), client.GetSigningAlg() } - return -} - -func toRFCErr(v *jwt.ValidationError) *oauth2.RFC6749Error { - switch { - case v == nil: - return nil - case v.Has(jwt.ValidationErrorMalformed): - return oauth2.ErrInvalidTokenFormat - case v.Has(jwt.ValidationErrorUnverifiable | jwt.ValidationErrorSignatureInvalid): - return oauth2.ErrTokenSignatureMismatch - case v.Has(jwt.ValidationErrorExpired): - return oauth2.ErrTokenExpired - case v.Has(jwt.ValidationErrorAudience | - jwt.ValidationErrorIssuedAt | - jwt.ValidationErrorIssuer | - jwt.ValidationErrorNotValidYet | - jwt.ValidationErrorId | - jwt.ValidationErrorClaimsInvalid): - return oauth2.ErrTokenClaim - default: - return oauth2.ErrRequestUnauthorized + if errJWTValidation := new(jwt.ValidationError); errors.As(inner, &errJWTValidation) { + switch { + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyIDInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'kid' value '%s' but it was signed with the 'kid' value '%s'.", clientText, skid, token.KeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderAlgorithmInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'alg' value '%s' but it was signed with the 'alg' value '%s'.", clientText, salg, token.SignatureAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderTypeInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'typ' value '%s' but it was signed with the 'typ' value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorMalformed): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis malformed. %s.", clientText, strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): + return oauth2.ErrTokenSignatureMismatch.WithDebugf("Token %sis not able to be verified. %s.", clientText, strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorSignatureInvalid): + return oauth2.ErrTokenSignatureMismatch.WithDebugf("Token %shas an invalid signature.", clientText) + case errJWTValidation.Has(jwt.ValidationErrorExpired): + exp, ok := token.Claims.GetExpiresAt() + if ok { + return oauth2.ErrTokenExpired.WithDebugf("Token %sexpired at %d.", clientText, exp) + } else { + return oauth2.ErrTokenExpired.WithDebugf("Token %sdoes not have an 'exp' claim or it has an invalid type.", clientText) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuedAt): + iat, ok := token.Claims.GetIssuedAt() + if ok { + return oauth2.ErrTokenClaim.WithDebugf("Token %sis issued in the future. The token was issued at %d.", clientText, iat) + } else { + return oauth2.ErrTokenClaim.WithDebugf("Token %sis issued in the future. The token does not have an 'iat' claim or it has an invalid type.", clientText) + } + case errJWTValidation.Has(jwt.ValidationErrorNotValidYet): + nbf, ok := token.Claims.GetNotBefore() + if ok { + return oauth2.ErrTokenClaim.WithDebugf("Token %sis not valid yet. The token is not valid before %d.", clientText, nbf) + } else { + return oauth2.ErrTokenClaim.WithDebugf("Token %sis not valid yet. The token does not have an 'nbf' claim or it has an invalid type.", clientText) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuer): + iss, ok := token.Claims.GetIssuer() + if ok { + return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid issuer. The token was expected to have an 'iss' claim with one of the following values: ''. The 'iss' claim has a value of '%s'.", clientText, iss) + } else { + return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid issuer. The token does not have an 'iss' claim or it has an invalid type.", clientText) + } + case errJWTValidation.Has(jwt.ValidationErrorAudience): + aud, ok := token.Claims.GetAudience() + if ok { + return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid audience. The token was expected to have an 'iss' claim with one of the following values: ''. The 'iss' claim has a value of '%s'.", clientText, aud) + } else { + return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid audience. The token does not have an 'iss' claim or it has an invalid type.", clientText) + } + case errJWTValidation.Has(jwt.ValidationErrorClaimsInvalid): + return oauth2.ErrTokenClaim.WithDebugf("Token %shas invalid claims. Error occurred trying to validate the request objects claims: %s", clientText, strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + default: + return oauth2.ErrTokenClaim.WithDebugf("Token %scould not be validated. Error occurred trying to validate the token: %s", clientText, strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + } + } else if errJWKLookup := new(jwt.JWKLookupError); errors.As(inner, &errJWKLookup) { + return oauth2.ErrRequestUnauthorized.WithDebugf("Token %scould not be validated due to a key lookup error. %s.", clientText, errJWKLookup.Description) + } else { + return oauth2.ErrRequestUnauthorized.WithDebugf("Token %scould not be validated. %s", clientText, oauth2.ErrorToDebugRFC6749Error(inner).Error()) } } diff --git a/handler/oauth2/strategy_jwt_profile_test.go b/handler/oauth2/strategy_jwt_profile_test.go index 163e2ad8..a29476c6 100644 --- a/handler/oauth2/strategy_jwt_profile_test.go +++ b/handler/oauth2/strategy_jwt_profile_test.go @@ -55,6 +55,35 @@ var jwtValidCase = func(tokenType oauth2.TokenType) *oauth2.Request { return r } +var jwtInvalidTypCase = func(tokenType oauth2.TokenType) *oauth2.Request { + r := &oauth2.Request{ + Client: &oauth2.DefaultClient{ + ClientSecret: mustNewBCryptClientSecretPlain("foobarfoobarfoobarfoobar"), + }, + Session: &JWTSession{ + JWTClaims: &jwt.JWTClaims{ + Issuer: "oauth2", + Subject: "peter", + IssuedAt: time.Now().UTC(), + NotBefore: time.Now().UTC(), + Extra: map[string]any{"foo": "bar"}, + }, + JWTHeader: &jwt.Headers{ + Extra: map[string]any{consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeJWT}, + }, + ExpiresAt: map[oauth2.TokenType]time.Time{ + tokenType: time.Now().UTC().Add(time.Hour), + }, + }, + } + r.SetRequestedScopes([]string{consts.ScopeEmail, consts.ScopeOffline}) + r.GrantScope(consts.ScopeEmail) + r.GrantScope(consts.ScopeOffline) + r.SetRequestedAudience([]string{"group0"}) + r.GrantAudience("group0") + return r +} + var jwtValidCaseWithZeroRefreshExpiry = func(tokenType oauth2.TokenType) *oauth2.Request { r := &oauth2.Request{ Client: &oauth2.DefaultClient{ @@ -118,7 +147,7 @@ var jwtValidCaseWithRefreshExpiry = func(tokenType oauth2.TokenType) *oauth2.Req // returns an expired JWT type. The JWTClaims.ExpiresAt time is intentionally // left empty to ensure it is pulled from the session's ExpiresAt map for // the given oauth2.TokenType. -var jwtExpiredCase = func(tokenType oauth2.TokenType) *oauth2.Request { +var jwtExpiredCase = func(tokenType oauth2.TokenType, now time.Time) *oauth2.Request { r := &oauth2.Request{ Client: &oauth2.DefaultClient{ ClientSecret: mustNewBCryptClientSecretPlain("foobarfoobarfoobarfoobar"), @@ -127,16 +156,16 @@ var jwtExpiredCase = func(tokenType oauth2.TokenType) *oauth2.Request { JWTClaims: &jwt.JWTClaims{ Issuer: "oauth2", Subject: "peter", - IssuedAt: time.Now().UTC().Add(-time.Minute * 10), - NotBefore: time.Now().UTC().Add(-time.Minute * 10), - ExpiresAt: time.Now().UTC().Add(-time.Minute), + IssuedAt: now.UTC().Add(-time.Minute * 10), + NotBefore: now.UTC().Add(-time.Minute * 10), + ExpiresAt: now.UTC().Add(-time.Minute), Extra: map[string]any{"foo": "bar"}, }, JWTHeader: &jwt.Headers{ Extra: make(map[string]any), }, ExpiresAt: map[oauth2.TokenType]time.Time{ - tokenType: time.Now().UTC().Add(-time.Hour), + tokenType: now.UTC().Add(-time.Hour), }, }, } @@ -163,7 +192,7 @@ func TestAccessToken(t *testing.T) { pass: true, }, { - r: jwtExpiredCase(oauth2.AccessToken), + r: jwtExpiredCase(oauth2.AccessToken, time.Unix(1726972738, 0)), pass: false, }, { diff --git a/token/jwt/client.go b/token/jwt/client.go index 3008117a..bca5214d 100644 --- a/token/jwt/client.go +++ b/token/jwt/client.go @@ -70,8 +70,8 @@ func NewStatelessJWTProfileIntrospectionClient(client any) Client { } type Client interface { - GetSignatureKeyID() (kid string) - GetSignatureAlg() (alg string) + GetSigningKeyID() (kid string) + GetSigningAlg() (alg string) GetEncryptionKeyID() (kid string) GetEncryptionAlg() (alg string) GetEncryptionEnc() (enc string) @@ -82,6 +82,9 @@ type Client interface { } type BaseClient interface { + // GetID returns the client ID. + GetID() string + // GetClientSecretPlainText returns the ClientSecret as plaintext if available. The semantics of this function // return values are important. // If the client is not configured with a secret the return should be: @@ -146,11 +149,11 @@ type decoratedJARClient struct { JARClient } -func (r *decoratedJARClient) GetSignatureKeyID() (kid string) { +func (r *decoratedJARClient) GetSigningKeyID() (kid string) { return r.GetRequestObjectSigningKeyID() } -func (r *decoratedJARClient) GetSignatureAlg() (alg string) { +func (r *decoratedJARClient) GetSigningAlg() (alg string) { return r.GetRequestObjectSigningAlg() } @@ -208,11 +211,11 @@ type decoratedIDTokenClient struct { IDTokenClient } -func (r *decoratedIDTokenClient) GetSignatureKeyID() (kid string) { +func (r *decoratedIDTokenClient) GetSigningKeyID() (kid string) { return r.GetIDTokenSignedResponseKeyID() } -func (r *decoratedIDTokenClient) GetSignatureAlg() (alg string) { +func (r *decoratedIDTokenClient) GetSigningAlg() (alg string) { return r.GetIDTokenSignedResponseAlg() } @@ -270,11 +273,11 @@ type decoratedJARMClient struct { JARMClient } -func (r *decoratedJARMClient) GetSignatureKeyID() (kid string) { +func (r *decoratedJARMClient) GetSigningKeyID() (kid string) { return r.GetAuthorizationSignedResponseKeyID() } -func (r *decoratedJARMClient) GetSignatureAlg() (alg string) { +func (r *decoratedJARMClient) GetSigningAlg() (alg string) { return r.GetAuthorizationSignedResponseAlg() } @@ -330,11 +333,11 @@ type decoratedUserInfoClient struct { UserInfoClient } -func (r *decoratedUserInfoClient) GetSignatureKeyID() (kid string) { +func (r *decoratedUserInfoClient) GetSigningKeyID() (kid string) { return r.GetUserinfoSignedResponseKeyID() } -func (r *decoratedUserInfoClient) GetSignatureAlg() (alg string) { +func (r *decoratedUserInfoClient) GetSigningAlg() (alg string) { return r.GetUserinfoSignedResponseAlg() } @@ -390,11 +393,11 @@ type decoratedJWTProfileAccessTokenClient struct { JWTProfileAccessTokenClient } -func (r *decoratedJWTProfileAccessTokenClient) GetSignatureKeyID() (kid string) { +func (r *decoratedJWTProfileAccessTokenClient) GetSigningKeyID() (kid string) { return r.GetAccessTokenSignedResponseKeyID() } -func (r *decoratedJWTProfileAccessTokenClient) GetSignatureAlg() (alg string) { +func (r *decoratedJWTProfileAccessTokenClient) GetSigningAlg() (alg string) { return r.GetAccessTokenSignedResponseAlg() } @@ -452,11 +455,11 @@ type decoratedIntrospectionClient struct { IntrospectionClient } -func (r *decoratedIntrospectionClient) GetSignatureKeyID() (kid string) { +func (r *decoratedIntrospectionClient) GetSigningKeyID() (kid string) { return r.GetIntrospectionSignedResponseKeyID() } -func (r *decoratedIntrospectionClient) GetSignatureAlg() (alg string) { +func (r *decoratedIntrospectionClient) GetSigningAlg() (alg string) { return r.GetIntrospectionSignedResponseAlg() } diff --git a/token/jwt/client_test.go b/token/jwt/client_test.go index 29d3f7c6..91e90280 100644 --- a/token/jwt/client_test.go +++ b/token/jwt/client_test.go @@ -7,6 +7,7 @@ import ( ) type testClient struct { + id string secret []byte secretNotPlainText bool secretNotDefined bool @@ -17,6 +18,10 @@ type testClient struct { jwksURI string } +func (r *testClient) GetID() string { + return r.id +} + func (r *testClient) GetClientSecretPlainText() (secret []byte, ok bool, err error) { if r.secretNotDefined { return nil, false, nil @@ -33,11 +38,11 @@ func (r *testClient) GetClientSecretPlainText() (secret []byte, ok bool, err err return nil, true, fmt.Errorf("not supported") } -func (r *testClient) GetSignatureKeyID() (kid string) { +func (r *testClient) GetSigningKeyID() (kid string) { return r.kid } -func (r *testClient) GetSignatureAlg() (alg string) { +func (r *testClient) GetSigningAlg() (alg string) { return r.alg } diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index 2bf35db4..ba759c77 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -67,7 +67,7 @@ func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (toke if keySig, err = j.Issuer.GetIssuerJWK(ctx, "", string(jose.RS256), consts.JSONWebTokenUseSignature); err != nil { return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving issuer jwk: %w", err)) } - } else if keySig, err = j.Issuer.GetIssuerJWK(ctx, o.client.GetSignatureKeyID(), o.client.GetSignatureAlg(), consts.JSONWebTokenUseSignature); err != nil { + } else if keySig, err = j.Issuer.GetIssuerJWK(ctx, o.client.GetSigningKeyID(), o.client.GetSigningAlg(), consts.JSONWebTokenUseSignature); err != nil { return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving issuer jwk: %w", err)) } From 6547ae2aabc2fb16173e9ee3420213e59a278be1 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sun, 22 Sep 2024 14:25:03 +1000 Subject: [PATCH 09/33] temp --- handler/oauth2/introspector_jwt_test.go | 2 +- internal/mock/client_secret.go | 83 +++++++++++++++++++++++++ internal/randx/sequence.go | 20 ++++-- token/jwt/token.go | 5 ++ 4 files changed, 104 insertions(+), 6 deletions(-) create mode 100644 internal/mock/client_secret.go diff --git a/handler/oauth2/introspector_jwt_test.go b/handler/oauth2/introspector_jwt_test.go index c3563a73..9f9d5d69 100644 --- a/handler/oauth2/introspector_jwt_test.go +++ b/handler/oauth2/introspector_jwt_test.go @@ -109,7 +109,7 @@ func TestIntrospectJWT(t *testing.T) { parts := strings.Split(tokenString, ".") require.Len(t, parts, 3, "%s - %v", tokenString, parts) dec, err := base64.RawURLEncoding.DecodeString(parts[1]) - assert.NoError(t, err) + require.NoError(t, err) s := strings.ReplaceAll(string(dec), "peter", "piper") parts[1] = base64.RawURLEncoding.EncodeToString([]byte(s)) diff --git a/internal/mock/client_secret.go b/internal/mock/client_secret.go new file mode 100644 index 00000000..6d430bf3 --- /dev/null +++ b/internal/mock/client_secret.go @@ -0,0 +1,83 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: authelia.com/provider/oauth2 (interfaces: ClientSecret) +// +// Generated by this command: +// +// mockgen -package internal -destination internal/mock_client_secret.go authelia.com/provider/oauth2 ClientSecret +// + +// Package internal is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockClientSecret is a mock of ClientSecret interface. +type MockClientSecret struct { + ctrl *gomock.Controller + recorder *MockClientSecretMockRecorder +} + +// MockClientSecretMockRecorder is the mock recorder for MockClientSecret. +type MockClientSecretMockRecorder struct { + mock *MockClientSecret +} + +// NewMockClientSecret creates a new mock instance. +func NewMockClientSecret(ctrl *gomock.Controller) *MockClientSecret { + mock := &MockClientSecret{ctrl: ctrl} + mock.recorder = &MockClientSecretMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClientSecret) EXPECT() *MockClientSecretMockRecorder { + return m.recorder +} + +// Compare mocks base method. +func (m *MockClientSecret) Compare(arg0 context.Context, arg1 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Compare", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Compare indicates an expected call of Compare. +func (mr *MockClientSecretMockRecorder) Compare(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Compare", reflect.TypeOf((*MockClientSecret)(nil).Compare), arg0, arg1) +} + +// GetPlainTextValue mocks base method. +func (m *MockClientSecret) GetPlainTextValue() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPlainTextValue") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPlainTextValue indicates an expected call of GetPlainTextValue. +func (mr *MockClientSecretMockRecorder) GetPlainTextValue() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPlainTextValue", reflect.TypeOf((*MockClientSecret)(nil).GetPlainTextValue)) +} + +// IsPlainText mocks base method. +func (m *MockClientSecret) IsPlainText() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsPlainText") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsPlainText indicates an expected call of IsPlainText. +func (mr *MockClientSecretMockRecorder) IsPlainText() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPlainText", reflect.TypeOf((*MockClientSecret)(nil).IsPlainText)) +} diff --git a/internal/randx/sequence.go b/internal/randx/sequence.go index 0b0163ac..f19cc1c3 100644 --- a/internal/randx/sequence.go +++ b/internal/randx/sequence.go @@ -13,20 +13,28 @@ var rander = rand.Reader // random function var ( // AlphaNum contains runes [abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789]. AlphaNum = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + // Alpha contains runes [abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ]. Alpha = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + // AlphaLowerNum contains runes [abcdefghijklmnopqrstuvwxyz0123456789]. AlphaLowerNum = []rune("abcdefghijklmnopqrstuvwxyz0123456789") + // AlphaUpperNum contains runes [ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789]. AlphaUpperNum = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + // AlphaLower contains runes [abcdefghijklmnopqrstuvwxyz]. AlphaLower = []rune("abcdefghijklmnopqrstuvwxyz") + // AlphaUpperVowels contains runes [AEIOUY]. AlphaUpperVowels = []rune("AEIOUY") + // AlphaUpperNoVowels contains runes [BCDFGHJKLMNPQRSTVWXZ]. AlphaUpperNoVowels = []rune("BCDFGHJKLMNPQRSTVWXZ") + // AlphaUpper contains runes [ABCDEFGHIJKLMNOPQRSTUVWXYZ]. AlphaUpper = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + // Numeric contains runes [0123456789]. Numeric = []rune("0123456789") ) @@ -36,11 +44,13 @@ func RuneSequence(l int, allowedRunes []rune) (seq []rune, err error) { c := big.NewInt(int64(len(allowedRunes))) seq = make([]rune, l) + var r *big.Int + for i := 0; i < l; i++ { - r, err := rand.Int(rander, c) - if err != nil { + if r, err = rand.Int(rander, c); err != nil { return seq, err } + rn := allowedRunes[r.Uint64()] seq[i] = rn } @@ -50,9 +60,9 @@ func RuneSequence(l int, allowedRunes []rune) (seq []rune, err error) { // MustString returns a random string sequence using the defined runes. Panics on error. func MustString(l int, allowedRunes []rune) string { - seq, err := RuneSequence(l, allowedRunes) - if err != nil { + if seq, err := RuneSequence(l, allowedRunes); err != nil { panic(err) + } else { + return string(seq) } - return string(seq) } diff --git a/token/jwt/token.go b/token/jwt/token.go index cc57f69a..ab82586e 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -355,6 +355,11 @@ func (t *Token) Valid(opts ...TokenValidationOption) (err error) { vErr := new(ValidationError) + if !t.valid { + vErr.Inner = errors.New("token has an invalid or unverified signature") + vErr.Errors |= ValidationErrorSignatureInvalid + } + if len(vopts.types) != 0 { if !validateTokenType(vopts.types, t.Header) { vErr.Inner = errors.New("token has an invalid typ") From de64627a4df493ce93c2e0b49768085a1bf85354 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Mon, 23 Sep 2024 05:50:54 +1000 Subject: [PATCH 10/33] client auth --- authorize_request_handler.go | 29 ++- client_authentication.go | 101 +++++++++ client_authentication_strategy.go | 273 ++++++++++++------------- client_authentication_test.go | 43 ++-- compose/compose.go | 1 - config_default.go | 6 + handler/oauth2/strategy_jwt_profile.go | 18 +- internal/consts/jwt.go | 1 - token/jwt/claims_map.go | 4 +- token/jwt/issuer.go | 33 ++- token/jwt/jwt_strategy.go | 49 ----- token/jwt/token.go | 86 ++++++-- token/jwt/util.go | 6 +- token/jwt/validation_error.go | 31 +-- token/jwt/validator.go | 2 +- 15 files changed, 409 insertions(+), 274 deletions(-) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 92ee5245..2d8647b9 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -125,34 +125,23 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co return errorsx.WithStack(fmtRequestObjectDecodeError(token, client, issuer, openid, err)) } - optsValidHeader := []jwt.TokenValidationOption{ + optsValidHeader := []jwt.HeaderValidationOption{ jwt.ValidateKeyID(client.GetRequestObjectSigningKeyID()), jwt.ValidateAlgorithm(client.GetRequestObjectSigningAlg()), + jwt.ValidateEncryptionKeyID(client.GetRequestObjectEncryptionKeyID()), + jwt.ValidateKeyAlgorithm(client.GetRequestObjectEncryptionAlg()), + jwt.ValidateContentEncryption(client.GetRequestObjectEncryptionEnc()), } if err = token.Valid(optsValidHeader...); err != nil { return errorsx.WithStack(fmtRequestObjectDecodeError(token, client, issuer, openid, err)) } - if algAny { - if token.SignatureAlgorithm == consts.JSONWebTokenAlgNone { - return errorsx.WithStack( - ErrInvalidRequestObject. - WithHintf("%s client provided a request object that has an invalid 'kid' or 'alg' header value.", hintRequestObjectPrefix(openid)). - WithDebugf("%s client with id '%s' was not explicitly registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'none' in the header.", hintRequestObjectPrefix(openid), client.GetID())) - } - } else if string(token.SignatureAlgorithm) != alg { - return errorsx.WithStack( - ErrInvalidRequestObject. - WithHintf("%s client provided a request object that has an invalid 'kid' or 'alg' header value.", hintRequestObjectPrefix(openid)). - WithDebugf("%s client with id '%s' was registered with a 'request_object_signing_alg' value of '%s' but the request object had the 'alg' value '%s' in the header.", hintRequestObjectPrefix(openid), client.GetID(), alg, token.SignatureAlgorithm)) - } - - if kid := client.GetRequestObjectSigningKeyID(); kid != "" && kid != token.KeyID { + if algAny && token.SignatureAlgorithm == consts.JSONWebTokenAlgNone { return errorsx.WithStack( ErrInvalidRequestObject. WithHintf("%s client provided a request object that has an invalid 'kid' or 'alg' header value.", hintRequestObjectPrefix(openid)). - WithDebugf("%s client with id '%s' was registered with a 'request_object_signing_key_id' value of '%s' but the request object had the 'kid' value '%s' in the header.", hintRequestObjectPrefix(openid), client.GetID(), kid, token.KeyID)) + WithDebugf("%s client with id '%s' was not explicitly registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'none' in the header.", hintRequestObjectPrefix(openid), client.GetID())) } claims := token.Claims @@ -602,6 +591,12 @@ func fmtRequestObjectDecodeError(token *jwt.Token, client JARClient, issuer stri return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'alg' value '%s' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningAlg(), token.SignatureAlgorithm) case errJWTValidation.Has(jwt.ValidationErrorHeaderTypeInvalid): return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'typ' value '%s' but the request object was signed with the 'typ' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'kid' value '%s' due to the client registration 'request_object_encryption_key_id' value but the request object was encrypted with the 'kid' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionKeyID(), token.EncryptionKeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyAlgorithmInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'alg' value '%s' due to the client registration 'request_object_encryption_alg' value but the request object was encrypted with the 'alg' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionAlg(), token.KeyAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'enc' value '%s' due to the client registration 'request_object_encryption_enc' value but the request object was encrypted with the 'enc' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionEnc(), token.ContentEncryption) case errJWTValidation.Has(jwt.ValidationErrorMalformed): return outer.WithDebugf("%s client with id '%s' provided a request object that was malformed. %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): diff --git a/client_authentication.go b/client_authentication.go index 6bb6cfd0..b7610d3a 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -335,9 +335,17 @@ type EndpointClientAuthHandler interface { // GetAuthMethod returns the appropriate auth method for this client. GetAuthMethod(client AuthenticationMethodClient) string + GetAuthSigningKeyID(client AuthenticationMethodClient) string + // GetAuthSigningAlg returns the appropriate auth signature algorithm for this client. GetAuthSigningAlg(client AuthenticationMethodClient) string + GetAuthEncryptionKeyID(client AuthenticationMethodClient) string + + GetAuthEncryptionAlg(client AuthenticationMethodClient) string + + GetAuthEncryptionEnc(client AuthenticationMethodClient) string + // Name returns the appropriate name for this endpoint for logging purposes. Name() string @@ -345,16 +353,77 @@ type EndpointClientAuthHandler interface { AllowAuthMethodAny() bool } +type EndpointClientAuthJWTClient struct { + client AuthenticationMethodClient + handler EndpointClientAuthHandler +} + +func (c *EndpointClientAuthJWTClient) GetID() string { + return c.client.GetID() +} + +func (c *EndpointClientAuthJWTClient) GetClientSecretPlainText() (secret []byte, ok bool, err error) { + return c.client.GetClientSecretPlainText() +} + +func (c *EndpointClientAuthJWTClient) GetJSONWebKeys() (jwks *jose.JSONWebKeySet) { + return c.client.GetJSONWebKeys() +} + +func (c *EndpointClientAuthJWTClient) GetJSONWebKeysURI() (uri string) { + return c.client.GetJSONWebKeysURI() +} + +func (c *EndpointClientAuthJWTClient) GetSigningKeyID() (kid string) { + return "" +} + +func (c *EndpointClientAuthJWTClient) GetSigningAlg() (alg string) { + return c.handler.GetAuthSigningAlg(c.client) +} + +func (c *EndpointClientAuthJWTClient) GetEncryptionKeyID() (kid string) { + return "" +} + +func (c *EndpointClientAuthJWTClient) GetEncryptionAlg() (alg string) { + return "" +} + +func (c *EndpointClientAuthJWTClient) GetEncryptionEnc() (enc string) { + return "" +} + +func (c *EndpointClientAuthJWTClient) IsClientSigned() (is bool) { + return true +} + type TokenEndpointClientAuthHandler struct{} func (h *TokenEndpointClientAuthHandler) GetAuthMethod(client AuthenticationMethodClient) string { return client.GetTokenEndpointAuthMethod() } +func (h *TokenEndpointClientAuthHandler) GetAuthSigningKeyID(client AuthenticationMethodClient) string { + return "" +} + func (h *TokenEndpointClientAuthHandler) GetAuthSigningAlg(client AuthenticationMethodClient) string { return client.GetTokenEndpointAuthSigningAlg() } +func (h *TokenEndpointClientAuthHandler) GetAuthEncryptionKeyID(client AuthenticationMethodClient) string { + return "" +} + +func (h *TokenEndpointClientAuthHandler) GetAuthEncryptionAlg(client AuthenticationMethodClient) string { + return "" +} + +func (h *TokenEndpointClientAuthHandler) GetAuthEncryptionEnc(client AuthenticationMethodClient) string { + return "" +} + func (h *TokenEndpointClientAuthHandler) Name() string { return "token" } @@ -369,10 +438,26 @@ func (h *IntrospectionEndpointClientAuthHandler) GetAuthMethod(client Authentica return client.GetIntrospectionEndpointAuthMethod() } +func (h *IntrospectionEndpointClientAuthHandler) GetAuthSigningKeyID(client AuthenticationMethodClient) string { + return "" +} + func (h *IntrospectionEndpointClientAuthHandler) GetAuthSigningAlg(client AuthenticationMethodClient) string { return client.GetIntrospectionEndpointAuthSigningAlg() } +func (h *IntrospectionEndpointClientAuthHandler) GetAuthEncryptionKeyID(client AuthenticationMethodClient) string { + return "" +} + +func (h *IntrospectionEndpointClientAuthHandler) GetAuthEncryptionAlg(client AuthenticationMethodClient) string { + return "" +} + +func (h *IntrospectionEndpointClientAuthHandler) GetAuthEncryptionEnc(client AuthenticationMethodClient) string { + return "" +} + func (h *IntrospectionEndpointClientAuthHandler) Name() string { return "introspection" } @@ -387,10 +472,26 @@ func (h *RevocationEndpointClientAuthHandler) GetAuthMethod(client Authenticatio return client.GetRevocationEndpointAuthMethod() } +func (h *RevocationEndpointClientAuthHandler) GetAuthSigningKeyID(client AuthenticationMethodClient) string { + return "" +} + func (h *RevocationEndpointClientAuthHandler) GetAuthSigningAlg(client AuthenticationMethodClient) string { return client.GetRevocationEndpointAuthSigningAlg() } +func (h *RevocationEndpointClientAuthHandler) GetAuthEncryptionKeyID(client AuthenticationMethodClient) string { + return "" +} + +func (h *RevocationEndpointClientAuthHandler) GetAuthEncryptionAlg(client AuthenticationMethodClient) string { + return "" +} + +func (h *RevocationEndpointClientAuthHandler) GetAuthEncryptionEnc(client AuthenticationMethodClient) string { + return "" +} + func (h *RevocationEndpointClientAuthHandler) Name() string { return "revocation" } diff --git a/client_authentication_strategy.go b/client_authentication_strategy.go index 2b8541ea..a96ea824 100644 --- a/client_authentication_strategy.go +++ b/client_authentication_strategy.go @@ -13,6 +13,7 @@ import ( xjwt "github.com/golang-jwt/jwt/v5" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -21,12 +22,13 @@ type DefaultClientAuthenticationStrategy struct { ClientManager } Config interface { + JWTStrategyProvider JWKSFetcherStrategyProvider AllowedJWTAssertionAudiencesProvider } } -func (s *DefaultClientAuthenticationStrategy) AuthenticateClient(ctx context.Context, r *http.Request, form url.Values, resolver EndpointClientAuthHandler) (client Client, method string, err error) { +func (s *DefaultClientAuthenticationStrategy) AuthenticateClient(ctx context.Context, r *http.Request, form url.Values, handler EndpointClientAuthHandler) (client Client, method string, err error) { var ( id, secret string @@ -48,7 +50,7 @@ func (s *DefaultClientAuthenticationStrategy) AuthenticateClient(ctx context.Con var assertion *ClientAssertion if hasAssertion { - if assertion, err = NewClientAssertion(ctx, s.Store, assertionValue, assertionType, resolver); err != nil { + if assertion, err = NewClientAssertion(ctx, s.Config.GetJWTStrategy(ctx), s.Store, assertionValue, assertionType, handler); err != nil { return nil, "", err } } @@ -64,10 +66,10 @@ func (s *DefaultClientAuthenticationStrategy) AuthenticateClient(ctx context.Con hasNone := !hasPost && !hasBasic && assertion == nil && len(id) != 0 - return s.authenticate(ctx, id, secret, assertion, hasBasic, hasPost, hasNone, resolver) + return s.authenticate(ctx, id, secret, assertion, hasBasic, hasPost, hasNone, handler) } -func (s *DefaultClientAuthenticationStrategy) authenticate(ctx context.Context, id, secret string, assertion *ClientAssertion, hasBasic, hasPost, hasNone bool, resolver EndpointClientAuthHandler) (client Client, method string, err error) { +func (s *DefaultClientAuthenticationStrategy) authenticate(ctx context.Context, id, secret string, assertion *ClientAssertion, hasBasic, hasPost, hasNone bool, handler EndpointClientAuthHandler) (client Client, method string, err error) { var methods []string if hasBasic { @@ -115,16 +117,16 @@ func (s *DefaultClientAuthenticationStrategy) authenticate(ctx context.Context, return nil, "", errorsx.WithStack(ErrInvalidRequest. WithHintf("Client Authentication failed with more than one known authentication method included in the request which is not permitted."). - WithDebugf("The registered client with id '%s' and the authorization server policy does not permit this malformed request. The `%s_endpoint_auth_method` methods determined to be used were '%s'.", client.GetID(), resolver.Name(), strings.Join(methods, "', '"))) + WithDebugf("The registered client with id '%s' and the authorization server policy does not permit this malformed request. The `%s_endpoint_auth_method` methods determined to be used were '%s'.", client.GetID(), handler.Name(), strings.Join(methods, "', '"))) } switch { case assertion != nil: - method, err = s.doAuthenticateAssertionJWTBearer(ctx, client, assertion, resolver) + method, err = s.doAuthenticateAssertionJWTBearer(ctx, client, assertion, handler) case hasBasic, hasPost: - method, err = s.doAuthenticateClientSecret(ctx, client, secret, hasBasic, hasPost, resolver) + method, err = s.doAuthenticateClientSecret(ctx, client, secret, hasBasic, hasPost, handler) default: - method, err = s.doAuthenticateNone(ctx, client, resolver) + method, err = s.doAuthenticateNone(ctx, client, handler) } if err != nil { @@ -135,9 +137,9 @@ func (s *DefaultClientAuthenticationStrategy) authenticate(ctx context.Context, } // NewClientAssertion converts a raw assertion string into a *ClientAssertion. -func NewClientAssertion(ctx context.Context, store ClientManager, assertion, assertionType string, resolver EndpointClientAuthHandler) (a *ClientAssertion, err error) { +func NewClientAssertion(ctx context.Context, strategy jwt.Strategy, store ClientManager, assertion, assertionType string, handler EndpointClientAuthHandler) (a *ClientAssertion, err error) { var ( - token *xjwt.Token + token *jwt.Token id, alg, method string client Client @@ -152,12 +154,14 @@ func NewClientAssertion(ctx context.Context, store ClientManager, assertion, ass return &ClientAssertion{Assertion: assertion, Type: assertionType}, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unknown client_assertion_type '%s'.", assertionType)) } - if token, _, err = xjwt.NewParser(xjwt.WithoutClaimsValidation()).ParseUnverified(assertion, &xjwt.MapClaims{}); err != nil { + if token, err = strategy.Decode(ctx, assertion, jwt.WithAllowUnverified()); err != nil { return &ClientAssertion{Assertion: assertion, Type: assertionType}, resolveJWTErrorToRFCError(err) } - if id, err = token.Claims.GetSubject(); err != nil { - if id, err = token.Claims.GetIssuer(); err != nil { + var ok bool + + if id, ok = token.Claims.GetSubject(); !ok { + if id, ok = token.Claims.GetIssuer(); !ok { return &ClientAssertion{Assertion: assertion, Type: assertionType}, nil } } @@ -166,8 +170,10 @@ func NewClientAssertion(ctx context.Context, store ClientManager, assertion, ass return &ClientAssertion{Assertion: assertion, Type: assertionType, ID: id}, nil } - if c, ok := client.(AuthenticationMethodClient); ok { - alg, method = resolver.GetAuthSigningAlg(c), resolver.GetAuthMethod(c) + var c AuthenticationMethodClient + + if c, ok = client.(AuthenticationMethodClient); ok { + alg, method = handler.GetAuthSigningAlg(c), handler.GetAuthMethod(c) } return &ClientAssertion{ @@ -242,35 +248,38 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateClientSecret(ctx con } } -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(ctx context.Context, client Client, assertion *ClientAssertion, resolver EndpointClientAuthHandler) (method string, err error) { +func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(ctx context.Context, client Client, assertion *ClientAssertion, handler EndpointClientAuthHandler) (method string, err error) { var ( - token *xjwt.Token - claims *xjwt.RegisteredClaims + token *jwt.Token ) - if method, _, _, token, claims, err = s.doAuthenticateAssertionParseAssertionJWTBearer(ctx, client, assertion, resolver); err != nil { + if method, _, _, token, err = s.doAuthenticateAssertionParseAssertionJWTBearer(ctx, client, assertion, handler); err != nil { return "", err } if token == nil { - return "", err + return "", errorsx.WithStack(ErrInvalidClient.WithDebug("The client assertion did not result in a parsed token.")) } clientID := []byte(client.GetID()) + claims := &jwt.JWTClaims{} + + claims.FromMapClaims(token.Claims) + switch { case subtle.ConstantTimeCompare([]byte(claims.Issuer), clientID) == 0: return "", errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) case subtle.ConstantTimeCompare([]byte(claims.Subject), clientID) == 0: return "", errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) - case claims.ID == "": + case claims.JTI == "": return "", errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'jti' from 'client_assertion' must be set but is not.")) default: - if err = s.Store.ClientAssertionJWTValid(ctx, claims.ID); err != nil { + if err = s.Store.ClientAssertionJWTValid(ctx, claims.JTI); err != nil { return "", errorsx.WithStack(ErrJTIKnown.WithHint("Claim 'jti' from 'client_assertion' MUST only be used once.").WithDebugError(err)) } - if err = s.Store.SetClientAssertionJWT(ctx, claims.ID, time.Unix(claims.ExpiresAt.Unix(), 0)); err != nil { + if err = s.Store.SetClientAssertionJWT(ctx, claims.JTI, time.Unix(claims.ExpiresAt.Unix(), 0)); err != nil { return "", err } @@ -278,147 +287,49 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(c } } -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearer(ctx context.Context, client Client, assertion *ClientAssertion, resolver EndpointClientAuthHandler) (method, kid, alg string, token *xjwt.Token, claims *xjwt.RegisteredClaims, err error) { +func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearer(ctx context.Context, client Client, assertion *ClientAssertion, handler EndpointClientAuthHandler) (method, kid, alg string, token *jwt.Token, err error) { audience := s.Config.GetAllowedJWTAssertionAudiences(ctx) if len(audience) == 0 { - return "", "", "", nil, nil, errorsx.WithStack(ErrMisconfiguration.WithHint("The authorization server does not support OAuth 2.0 JWT Profile Client Authentication RFC7523 or OpenID Connect 1.0 specific authentication methods.").WithDebug("The authorization server could not determine any safe value for it's audience but it's required to validate the RFC7523 client assertions.")) - } - - opts := []xjwt.ParserOption{ - xjwt.WithStrictDecoding(), - //xjwt.WithAudience(tokenURI), // Satisfies RFC7523 Section 3 Point 3. - xjwt.WithExpirationRequired(), // Satisfies RFC7523 Section 3 Point 4. - xjwt.WithIssuedAt(), // Satisfies RFC7523 Section 3 Point 6. + return "", "", "", nil, errorsx.WithStack(ErrMisconfiguration.WithHint("The authorization server does not support OAuth 2.0 JWT Profile Client Authentication RFC7523 or OpenID Connect 1.0 specific authentication methods.").WithDebug("The authorization server could not determine any safe value for it's audience but it's required to validate the RFC7523 client assertions.")) } - // Automatically satisfies RFC7523 Section 3 Point 5, 8, 9, and 10. - parser := xjwt.NewParser(opts...) - - claims = &xjwt.RegisteredClaims{} - - if token, err = parser.ParseWithClaims(assertion.Assertion, claims, func(token *xjwt.Token) (key any, err error) { - if subtle.ConstantTimeCompare([]byte(client.GetID()), []byte(claims.Subject)) == 0 { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The supplied 'client_id' did not match the 'sub' claim of the 'client_assertion'.")) - } - - // The following check satisfies RFC7523 Section 3 Point 2. - // See: https://datatracker.ietf.org/doc/html/rfc7523#section-3. - if claims.Subject == "" { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The claim 'sub' from the 'client_assertion' isn't defined.")) - } - - var ( - c AuthenticationMethodClient - ok bool - ) - - if c, ok = client.(AuthenticationMethodClient); !ok { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("The registered client does not support OAuth 2.0 JWT Profile Client Authentication RFC7523 or OpenID Connect 1.0 specific authentication methods.")) - } + var ( + c AuthenticationMethodClient + ok bool + ) - return s.doAuthenticateAssertionParseAssertionJWTBearerFindKey(ctx, token.Header, c, resolver) - }); err != nil { - return "", "", "", nil, nil, resolveJWTErrorToRFCError(err) + if c, ok = client.(AuthenticationMethodClient); !ok { + return "", "", "", nil, errorsx.WithStack(ErrInvalidRequest.WithHint("The registered client does not support OAuth 2.0 JWT Profile Client Authentication RFC7523 or OpenID Connect 1.0 specific authentication methods.")) } - // Satisfies RFC7523 Section 3 Point 3. - if err = s.doAuthenticateAssertionJWTBearerClaimAudience(ctx, audience, claims); err != nil { - return "", "", "", nil, nil, err + if token, err = s.Config.GetJWTStrategy(ctx).Decode(ctx, assertion.Assertion, jwt.WithClient(&EndpointClientAuthJWTClient{client: c, handler: handler})); err != nil { + return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, c, handler, audience, err)) } - return method, kid, alg, token, claims, nil -} - -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearerClaimAudience(ctx context.Context, audience []string, claims *xjwt.RegisteredClaims) (err error) { - if len(claims.Audience) == 0 { - return errorsx.WithStack( - ErrInvalidClient. - WithHint("Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint."). - WithDebug("Unable to validate the 'aud' claim of the 'client_assertion' as it was empty."), - ) + optsClaims := []jwt.ClaimValidationOption{ + jwt.ValidateAudienceAny(audience...), // Satisfies RFC7523 Section 3 Point 3. + jwt.ValidateRequireExpiresAt(), // Satisfies RFC7523 Section 3 Point 4. + jwt.ValidateTimeFunc(time.Now), } - validAudience := false - - var aud, unverified string - -verification: - for _, unverified = range claims.Audience { - for _, aud = range audience { - if subtle.ConstantTimeCompare([]byte(aud), []byte(unverified)) == 1 { - validAudience = true - break verification - } - } + if err = token.Claims.Valid(optsClaims...); err != nil { + return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, c, handler, audience, err)) } - if !validAudience { - return errorsx.WithStack( - ErrInvalidClient. - WithHint("Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint."). - WithDebugf("Unable to validate the 'aud' claim of the 'client_assertion' value '%s' as it doesn't match any of the expected values '%s'.", strings.Join(claims.Audience, "', '"), strings.Join(audience, "', '")), - ) + optsHeader := []jwt.HeaderValidationOption{ + jwt.ValidateKeyID(handler.GetAuthSigningKeyID(c)), + jwt.ValidateAlgorithm(handler.GetAuthSigningAlg(c)), + jwt.ValidateEncryptionKeyID(handler.GetAuthEncryptionKeyID(c)), + jwt.ValidateKeyAlgorithm(handler.GetAuthEncryptionAlg(c)), + jwt.ValidateContentEncryption(handler.GetAuthEncryptionEnc(c)), } - return nil -} - -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearerFindKey(ctx context.Context, header map[string]any, client AuthenticationMethodClient, handler EndpointClientAuthHandler) (key any, err error) { - var kid, alg, method string - - kid, alg = getJWTHeaderKIDAlg(header) - - if calg := handler.GetAuthSigningAlg(client); calg != alg && calg != "" { - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The requested OAuth 2.0 client does not support the '%s_endpoint_auth_signing_alg' value '%s'.", handler.Name(), alg).WithDebugf("The registered OAuth 2.0 client with id '%s' only supports the '%s' algorithm.", client.GetID(), calg)) + if err = token.Valid(optsHeader...); err != nil { + return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, c, handler, audience, err)) } - switch method = handler.GetAuthMethod(client); method { - case consts.ClientAuthMethodClientSecretJWT: - return s.doAuthenticateAssertionParseAssertionJWTBearerFindKeyClientSecretJWT(ctx, kid, alg, client, handler) - case consts.ClientAuthMethodPrivateKeyJWT: - return s.doAuthenticateAssertionParseAssertionJWTBearerFindKeyPrivateKeyJWT(ctx, kid, alg, client, handler) - case consts.ClientAuthMethodNone: - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This requested OAuth 2.0 client does not support client authentication, however 'client_assertion' was provided in the request.")) - case consts.ClientAuthMethodClientSecretBasic, consts.ClientAuthMethodClientSecretPost: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however 'client_assertion' was provided in the request.", method)) - default: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however that method is not supported by this server.", method)) - } -} - -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearerFindKeyClientSecretJWT(_ context.Context, _, alg string, client AuthenticationMethodClient, handler EndpointClientAuthHandler) (key any, err error) { - switch alg { - case xjwt.SigningMethodHS256.Alg(), xjwt.SigningMethodHS384.Alg(), xjwt.SigningMethodRS512.Alg(): - secret := client.GetClientSecret() - - if secret == nil || !secret.IsPlainText() { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 client does not support the client authentication method 'client_secret_jwt' ")) - } - - if key, err = secret.GetPlainTextValue(); err != nil { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 client does not support the client authentication method 'client_secret_jwt' ")) - } - - return key, nil - default: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The requested OAuth 2.0 client does not support the '%s_endpoint_auth_signing_alg' value '%s'.", handler.Name(), alg)) - } -} - -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearerFindKeyPrivateKeyJWT(ctx context.Context, kid, alg string, client AuthenticationMethodClient, handler EndpointClientAuthHandler) (key any, err error) { - switch alg { - case xjwt.SigningMethodRS256.Alg(), xjwt.SigningMethodRS384.Alg(), xjwt.SigningMethodRS512.Alg(), - xjwt.SigningMethodPS256.Alg(), xjwt.SigningMethodPS384.Alg(), xjwt.SigningMethodPS512.Alg(), - xjwt.SigningMethodES256.Alg(), xjwt.SigningMethodES384.Alg(), xjwt.SigningMethodES512.Alg(): - if key, err = FindClientPublicJWK(ctx, s.Config, client, kid, alg, "sig"); err != nil { - return nil, err - } - - return key, nil - default: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The requested OAuth 2.0 client does not support the '%s_endpoint_auth_signing_alg' value '%s'.", handler.Name(), alg)) - } + return method, kid, alg, token, nil } func (s *DefaultClientAuthenticationStrategy) getClientCredentialsSecretPost(form url.Values) (id, secret string, ok bool) { @@ -443,3 +354,73 @@ func resolveJWTErrorToRFCError(err error) (rfc error) { return errorsx.WithStack(ErrInvalidClient.WithHint("Unable to decode 'client_assertion' value for an unknown reason.").WithWrap(err).WithDebugError(err)) } } + +func fmtClientAssertionDecodeError(token *jwt.Token, client AuthenticationMethodClient, handler EndpointClientAuthHandler, audience []string, inner error) (outer *RFC6749Error) { + outer = ErrInvalidClient.WithWrap(inner).WithHintf("OAuth 2.0 client with id '%s' provided a client assertion which could not be decoded or validated.", client.GetID()) + + if errJWTValidation := new(jwt.ValidationError); errors.As(inner, &errJWTValidation) { + switch { + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyIDInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'kid' value '%s' due to the client registration 'request_object_signing_key_id' value but the client assertion was signed with the 'kid' value '%s'.", client.GetID(), handler.GetAuthSigningKeyID(client), token.KeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderAlgorithmInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'alg' value '%s' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' value '%s'.", client.GetID(), handler.GetAuthSigningAlg(client), token.SignatureAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderTypeInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'typ' value '%s' but the client assertion was signed with the 'typ' value '%s'.", client.GetID(), consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'kid' value '%s' due to the client registration 'request_object_encryption_key_id' value but the client assertion was encrypted with the 'kid' value '%s'.", client.GetID(), handler.GetAuthEncryptionKeyID(client), token.EncryptionKeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyAlgorithmInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'alg' value '%s' due to the client registration 'request_object_encryption_alg' value but the client assertion was encrypted with the 'alg' value '%s'.", client.GetID(), handler.GetAuthEncryptionAlg(client), token.KeyAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'enc' value '%s' due to the client registration 'request_object_encryption_enc' value but the client assertion was encrypted with the 'enc' value '%s'.", client.GetID(), handler.GetAuthEncryptionEnc(client), token.ContentEncryption) + case errJWTValidation.Has(jwt.ValidationErrorMalformed): + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was malformed. %s.", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was not able to be verified. %s.", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorSignatureInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid signature. %s.", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorExpired): + exp, ok := token.Claims.GetExpiresAt() + if ok { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was expired. The client assertion expired at %d.", client.GetID(), exp) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was expired. The client assertion does not have an 'exp' claim or it has an invalid type.", client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuedAt): + iat, ok := token.Claims.GetIssuedAt() + if ok { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion was issued at %d.", client.GetID(), iat) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion does not have an 'iat' claim or it has an invalid type.", client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorNotValidYet): + nbf, ok := token.Claims.GetNotBefore() + if ok { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion is not valid before %d.", client.GetID(), nbf) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion does not have an 'nbf' claim or it has an invalid type.", client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuer): + iss, ok := token.Claims.GetIssuer() + if ok { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid issuer. The client assertion was expected to have an 'iss' claim which matches the value '%s' but the 'iss' claim had the value '%s'.", client.GetID(), client.GetID(), iss) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid issuer. The client assertion does not have an 'iss' claim or it has an invalid type.", client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorAudience): + aud, ok := token.Claims.GetAudience() + if ok { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid audience. The client assertion was expected to have an 'aud' claim which matches one of the values '%s' but the 'aud' claim had the values '%s'.", client.GetID(), strings.Join(audience, "', '"), strings.Join(aud, "', '")) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid audience. The client assertion does not have an 'aud' claim or it has an invalid type.", client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorClaimsInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that had one or more invalid claims. Error occurred trying to validate the client assertions claims: %s", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + default: + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that could not be validated. Error occurred trying to validate the client assertion: %s", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + } + } else if errJWKLookup := new(jwt.JWKLookupError); errors.As(inner, &errJWKLookup) { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that could not be validated due to a key lookup error. %s.", client.GetID(), errJWKLookup.Description) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that could not be validated. %s.", client.GetID(), ErrorToDebugRFC6749Error(inner).Error()) + } +} diff --git a/client_authentication_test.go b/client_authentication_test.go index 3995341b..15a186a0 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -32,9 +32,10 @@ func TestAuthenticateClient(t *testing.T) { jwksRSA := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - KeyID: "kid-foo", - Use: "sig", - Key: &keyRSA.PublicKey, + KeyID: "kid-foo", + Use: "sig", + Algorithm: "RS256", + Key: &keyRSA.PublicKey, }, }, } @@ -694,13 +695,20 @@ func TestAuthenticateClient(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + config := &Config{ + JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), + AllowedJWTAssertionAudiences: []string{"token-url"}, + HTTPClient: retryablehttp.NewClient(), + } + + config.JWTStrategy = &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerUnverifiedFromJWKS(jwksRSA), + } + provider := &Fosite{ - Store: storage.NewMemoryStore(), - Config: &Config{ - JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), - AllowedJWTAssertionAudiences: []string{"token-url"}, - HTTPClient: retryablehttp.NewClient(), - }, + Store: storage.NewMemoryStore(), + Config: config, } var h http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { @@ -761,12 +769,19 @@ func TestAuthenticateClientTwice(t *testing.T) { store := storage.NewMemoryStore() store.Clients[client.ID] = client + config := &Config{ + JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), + AllowedJWTAssertionAudiences: []string{"token-url"}, + } + + config.JWTStrategy = &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + provider := &Fosite{ - Store: store, - Config: &Config{ - JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), - AllowedJWTAssertionAudiences: []string{"token-url"}, - }, + Store: store, + Config: config, } formValues := url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ diff --git a/compose/compose.go b/compose/compose.go index dde1380c..73ef4b40 100644 --- a/compose/compose.go +++ b/compose/compose.go @@ -82,7 +82,6 @@ func ComposeAllEnabled(config *oauth2.Config, storage any, key any) oauth2.Provi CoreStrategy: NewOAuth2HMACStrategy(config), OpenIDConnectTokenStrategy: NewOpenIDConnectStrategy(keyGetter, strategy, config), Strategy: strategy, - //Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter}, }, OAuth2AuthorizeExplicitFactory, OAuth2AuthorizeImplicitFactory, diff --git a/config_default.go b/config_default.go index 7c7677e5..961f1856 100644 --- a/config_default.go +++ b/config_default.go @@ -397,6 +397,12 @@ func (c *Config) GetJWTSecuredAuthorizeResponseModeStrategy(ctx context.Context) } func (c *Config) GetJWTStrategy(ctx context.Context) jwt.Strategy { + if c.JWTStrategy == nil { + c.JWTStrategy = &jwt.DefaultStrategy{ + Config: c, + } + } + return c.JWTStrategy } diff --git a/handler/oauth2/strategy_jwt_profile.go b/handler/oauth2/strategy_jwt_profile.go index 95608478..d0703554 100644 --- a/handler/oauth2/strategy_jwt_profile.go +++ b/handler/oauth2/strategy_jwt_profile.go @@ -188,23 +188,31 @@ func validateJWT(ctx context.Context, strategy jwt.Strategy, client jwt.Client, func fmtValidateJWTError(token *jwt.Token, client jwt.Client, inner error) (err error) { var ( - clientText string - skid, salg string + clientText string + sigKID, sigAlg string + encKID, encAlg, enc string ) if client != nil { clientText = fmt.Sprintf("provided by client with id '%s' ", client.GetID()) - skid, salg = client.GetSigningKeyID(), client.GetSigningAlg() + sigKID, sigAlg = client.GetSigningKeyID(), client.GetSigningAlg() + encKID, encAlg, enc = client.GetEncryptionKeyID(), client.GetEncryptionAlg(), client.GetEncryptionEnc() } if errJWTValidation := new(jwt.ValidationError); errors.As(inner, &errJWTValidation) { switch { case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyIDInvalid): - return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'kid' value '%s' but it was signed with the 'kid' value '%s'.", clientText, skid, token.KeyID) + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'kid' value '%s' but it was signed with the 'kid' value '%s'.", clientText, sigKID, token.KeyID) case errJWTValidation.Has(jwt.ValidationErrorHeaderAlgorithmInvalid): - return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'alg' value '%s' but it was signed with the 'alg' value '%s'.", clientText, salg, token.SignatureAlgorithm) + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'alg' value '%s' but it was signed with the 'alg' value '%s'.", clientText, sigAlg, token.SignatureAlgorithm) case errJWTValidation.Has(jwt.ValidationErrorHeaderTypeInvalid): return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'typ' value '%s' but it was signed with the 'typ' value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'kid' value '%s' but it was encrypted with the 'kid' value '%s'.", clientText, encKID, token.EncryptionKeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyAlgorithmInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'alg' value '%s' but it was encrypted with the 'alg' value '%s'.", clientText, encAlg, token.KeyAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'enc' value '%s' but it was encrypted with the 'enc' value '%s'.", clientText, enc, token.ContentEncryption) case errJWTValidation.Has(jwt.ValidationErrorMalformed): return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis malformed. %s.", clientText, strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): diff --git a/internal/consts/jwt.go b/internal/consts/jwt.go index 9dc52a9c..ecc35b4f 100644 --- a/internal/consts/jwt.go +++ b/internal/consts/jwt.go @@ -7,7 +7,6 @@ const ( JSONWebTokenHeaderCompressionAlgorithm = "zip" JSONWebTokenHeaderPBES2Count = "p2c" - JSONWebTokenHeaderUse = "use" JSONWebTokenHeaderType = "typ" JSONWebTokenHeaderContentType = "cty" ) diff --git a/token/jwt/claims_map.go b/token/jwt/claims_map.go index 8b57c01e..3a067301 100644 --- a/token/jwt/claims_map.go +++ b/token/jwt/claims_map.go @@ -182,9 +182,9 @@ func (m MapClaims) Valid(opts ...ClaimValidationOption) (err error) { var now int64 if vopts.timef != nil { - now = vopts.timef().Unix() + now = vopts.timef().UTC().Unix() } else { - now = TimeFunc().Unix() + now = TimeFunc().UTC().Unix() } vErr := new(ValidationError) diff --git a/token/jwt/issuer.go b/token/jwt/issuer.go index da07d599..de778d40 100644 --- a/token/jwt/issuer.go +++ b/token/jwt/issuer.go @@ -12,6 +12,7 @@ import ( "authelia.com/provider/oauth2/internal/consts" ) +// NewDefaultIssuer returns a new issuer and verifies that one RS256 key exists. func NewDefaultIssuer(keys ...jose.JSONWebKey) (issuer *DefaultIssuer, err error) { jwks := &jose.JSONWebKeySet{ Keys: make([]jose.JSONWebKey, len(keys)), @@ -22,6 +23,10 @@ func NewDefaultIssuer(keys ...jose.JSONWebKey) (issuer *DefaultIssuer, err error for i, key := range keys { jwks.Keys[i] = key + if hasRS256 { + continue + } + if key.Use != consts.JSONWebTokenUseSignature { continue } @@ -37,9 +42,31 @@ func NewDefaultIssuer(keys ...jose.JSONWebKey) (issuer *DefaultIssuer, err error return nil, errors.New("no RS256 signature algorithm found") } - return issuer, nil + return NewDefaultIssuerUnverifiedFromJWKS(jwks), nil +} + +func NewDefaultIssuerFromJWKS(jwks *jose.JSONWebKeySet) (issuer *DefaultIssuer, err error) { + for _, key := range jwks.Keys { + if key.Use != consts.JSONWebTokenUseSignature { + continue + } + + if key.Algorithm != string(jose.RS256) { + continue + } + + return &DefaultIssuer{jwks: jwks}, nil + } + + return nil, errors.New("no RS256 signature algorithm found") +} + +// NewDefaultIssuerUnverifiedFromJWKS returns a new issuer from a jose.JSONWebKeySet without verification. +func NewDefaultIssuerUnverifiedFromJWKS(jwks *jose.JSONWebKeySet) (issuer *DefaultIssuer) { + return &DefaultIssuer{jwks: jwks} } +// MustNewDefaultIssuerRS256 is the same as NewDefaultIssuerRS256 but it panics if an error occurs. func MustNewDefaultIssuerRS256(key any) (issuer *DefaultIssuer) { var err error @@ -50,6 +77,7 @@ func MustNewDefaultIssuerRS256(key any) (issuer *DefaultIssuer) { return issuer } +// NewDefaultIssuerRS256 returns an issuer with a single key and returns an error if it's not an RSA2048 or higher key. func NewDefaultIssuerRS256(key any) (issuer *DefaultIssuer, err error) { switch k := key.(type) { case *rsa.PrivateKey: @@ -63,6 +91,7 @@ func NewDefaultIssuerRS256(key any) (issuer *DefaultIssuer, err error) { } } +// NewDefaultIssuerRS256Unverified returns an issuer with a single key asserting the type is an RSA key. func NewDefaultIssuerRS256Unverified(key any) (issuer *DefaultIssuer) { return &DefaultIssuer{ jwks: &jose.JSONWebKeySet{ @@ -78,6 +107,7 @@ func NewDefaultIssuerRS256Unverified(key any) (issuer *DefaultIssuer) { } } +// GenDefaultIssuer generates a *DefaultIssuer with a random RSA key. func GenDefaultIssuer() (issuer *DefaultIssuer, err error) { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { @@ -87,6 +117,7 @@ func GenDefaultIssuer() (issuer *DefaultIssuer, err error) { return NewDefaultIssuerRS256(key) } +// MustGenDefaultIssuer is the same as GenDefaultIssuer but it panics on an error. func MustGenDefaultIssuer() (issuer *DefaultIssuer) { var err error diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index ba759c77..efa8657d 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -192,7 +192,6 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . } var ( - // key *jose.JSONWebKey t *jwt.JSONWebToken jwe *jose.JSONWebEncryption ) @@ -202,54 +201,6 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . return token, err } - /* - if IsEncryptedJWT(tokenString) { - if jwe, err = jose.ParseEncryptedCompact(tokenString, o.keyAlgorithm, o.contentEncryption); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } - - var ( - kid, alg, cty string - ) - - if kid, alg, _, cty, err = headerValidateJWE(jwe.Header); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } - - if o.jweKeyFunc != nil { - if key, err = o.jweKeyFunc(ctx, jwe, kid, alg); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - } else if IsEncryptedJWTClientSecretAlg(alg) { - if o.client == nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - - if key, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) - } - - var rawJWT []byte - - if rawJWT, err = jwe.Decrypt(key); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } - - if t, err = jwt.ParseSigned(string(rawJWT), o.sigAlgorithm); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } - - if err = headerValidateJWSNested(t.Headers, cty); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } - } else if t, err = jwt.ParseSigned(tokenString, o.sigAlgorithm); err != nil { - return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } - */ - if t, err = jwt.ParseSigned(tokenString, o.sigAlgorithm); err != nil { return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } diff --git a/token/jwt/token.go b/token/jwt/token.go index ab82586e..b887d826 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -135,7 +135,8 @@ func ParseCustomWithClaims(tokenString string, claims MapClaims, keyFunc Keyfunc // using go-json type Token struct { KeyID string - SignatureAlgorithm jose.SignatureAlgorithm // alg (JWS) + SignatureAlgorithm jose.SignatureAlgorithm // alg (JWS) + EncryptionKeyID string KeyAlgorithm jose.KeyAlgorithm // alg (JWE) ContentEncryption jose.ContentEncryption // enc (JWE) CompressionAlgorithm jose.CompressionAlgorithm // zip (JWE) @@ -194,18 +195,20 @@ func (t *Token) toEncryptedJoseHeader() (header map[jose.HeaderKey]any) { } // SetJWS sets the JWS output values. -func (t *Token) SetJWS(header Mapper, claims MapClaims, alg jose.SignatureAlgorithm) { +func (t *Token) SetJWS(header Mapper, claims MapClaims, kid string, alg jose.SignatureAlgorithm) { assign(t.Header, header.ToMap()) + t.KeyID = kid t.SignatureAlgorithm = alg t.Claims = claims } // SetJWE sets the JWE output values. -func (t *Token) SetJWE(header Mapper, alg jose.KeyAlgorithm, enc jose.ContentEncryption, zip jose.CompressionAlgorithm) { +func (t *Token) SetJWE(header Mapper, kid string, alg jose.KeyAlgorithm, enc jose.ContentEncryption, zip jose.CompressionAlgorithm) { assign(t.HeaderJWE, header.ToMap()) + t.EncryptionKeyID = kid t.KeyAlgorithm = alg t.ContentEncryption = enc t.CompressionAlgorithm = zip @@ -223,6 +226,7 @@ func (t *Token) AssignJWE(jwe *jose.JSONWebEncryption) { if jwe.Header.KeyID != "" { t.HeaderJWE[consts.JSONWebTokenHeaderKeyIdentifier] = jwe.Header.KeyID + t.EncryptionKeyID = jwe.Header.KeyID } for header, value := range jwe.Header.ExtraHeaders { @@ -344,8 +348,8 @@ func (t *Token) CompactSignedString(k any) (tokenString string, err error) { } // Valid validates the token headers given various input options. This does not validate any claims. -func (t *Token) Valid(opts ...TokenValidationOption) (err error) { - vopts := &TokenValidationOptions{ +func (t *Token) Valid(opts ...HeaderValidationOption) (err error) { + vopts := &HeaderValidationOptions{ types: []string{consts.JSONWebTokenTypeJWT}, } @@ -362,25 +366,46 @@ func (t *Token) Valid(opts ...TokenValidationOption) (err error) { if len(vopts.types) != 0 { if !validateTokenType(vopts.types, t.Header) { - vErr.Inner = errors.New("token has an invalid typ") + vErr.Inner = errors.New("token was signed with an invalid typ") vErr.Errors |= ValidationErrorHeaderTypeInvalid } } if len(vopts.alg) != 0 { if vopts.alg != string(t.SignatureAlgorithm) { - vErr.Inner = errors.New("token has an invalid alg") + vErr.Inner = errors.New("token was signed with an invalid alg") vErr.Errors |= ValidationErrorHeaderAlgorithmInvalid } } if len(vopts.kid) != 0 { if vopts.kid != t.KeyID { - vErr.Inner = errors.New("token has an invalid kid") + vErr.Inner = errors.New("token was signed with an invalid kid") vErr.Errors |= ValidationErrorHeaderKeyIDInvalid } } + if len(vopts.keyAlg) != 0 && len(t.KeyAlgorithm) != 0 { + if vopts.keyAlg != string(t.KeyAlgorithm) { + vErr.Inner = errors.New("token was encrypted with an invalid alg") + vErr.Errors |= ValidationErrorHeaderKeyAlgorithmInvalid + } + } + + if len(vopts.contentEnc) != 0 && len(t.ContentEncryption) != 0 { + if vopts.contentEnc != string(t.ContentEncryption) { + vErr.Inner = errors.New("token was encrypted with an invalid enc") + vErr.Errors |= ValidationErrorHeaderContentEncryptionInvalid + } + } + + if len(vopts.kidEnc) != 0 && len(t.EncryptionKeyID) != 0 { + if vopts.kidEnc != t.EncryptionKeyID { + vErr.Inner = errors.New("token was encrypted with an invalid kid") + vErr.Errors |= ValidationErrorHeaderEncryptionKeyIDInvalid + } + } + if vErr.valid() { return nil } @@ -418,29 +443,50 @@ func (t *Token) IsJWTProfileAccessToken() (ok bool) { return ok && (typ == consts.JSONWebTokenTypeAccessToken || typ == consts.JSONWebTokenTypeAccessTokenAlternative) } -type TokenValidationOption func(opts *TokenValidationOptions) +type HeaderValidationOption func(opts *HeaderValidationOptions) -type TokenValidationOptions struct { - types []string - alg string - kid string +type HeaderValidationOptions struct { + types []string + alg string + kid string + kidEnc string + keyAlg string + contentEnc string } -func ValidateTypes(types ...string) TokenValidationOption { - return func(validator *TokenValidationOptions) { +func ValidateTypes(types ...string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { validator.types = types } } -func ValidateAlgorithm(alg string) TokenValidationOption { - return func(validator *TokenValidationOptions) { +func ValidateKeyID(kid string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { + validator.kid = kid + } +} + +func ValidateAlgorithm(alg string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { validator.alg = alg } } -func ValidateKeyID(kid string) TokenValidationOption { - return func(validator *TokenValidationOptions) { - validator.kid = kid +func ValidateEncryptionKeyID(kid string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { + validator.kidEnc = kid + } +} + +func ValidateKeyAlgorithm(alg string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { + validator.keyAlg = alg + } +} + +func ValidateContentEncryption(enc string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { + validator.contentEnc = enc } } diff --git a/token/jwt/util.go b/token/jwt/util.go index c0f69c66..54bf3334 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -303,7 +303,7 @@ func NewJWKFromClientSecret(ctx context.Context, client BaseClient, kid, alg, us func encodeCompactSigned(ctx context.Context, claims MapClaims, headers Mapper, key *jose.JSONWebKey) (tokenString string, signature string, err error) { token := New() - token.SetJWS(headers, claims, jose.SignatureAlgorithm(key.Algorithm)) + token.SetJWS(headers, claims, key.KeyID, jose.SignatureAlgorithm(key.Algorithm)) return token.CompactSigned(key) } @@ -311,8 +311,8 @@ func encodeCompactSigned(ctx context.Context, claims MapClaims, headers Mapper, func encodeNestedCompactEncrypted(ctx context.Context, claims MapClaims, headers, headersJWE Mapper, keySig, keyEnc *jose.JSONWebKey, enc jose.ContentEncryption) (tokenString string, signature string, err error) { token := New() - token.SetJWS(headers, claims, jose.SignatureAlgorithm(keySig.Algorithm)) - token.SetJWE(headersJWE, jose.KeyAlgorithm(keyEnc.Algorithm), enc, jose.NONE) + token.SetJWS(headers, claims, keySig.KeyID, jose.SignatureAlgorithm(keySig.Algorithm)) + token.SetJWE(headersJWE, keyEnc.KeyID, jose.KeyAlgorithm(keyEnc.Algorithm), enc, jose.NONE) return token.CompactEncrypted(keySig, keyEnc) } diff --git a/token/jwt/validation_error.go b/token/jwt/validation_error.go index 548c581e..1e53cf03 100644 --- a/token/jwt/validation_error.go +++ b/token/jwt/validation_error.go @@ -6,20 +6,23 @@ package jwt // Validation provides a backwards compatible error definition // from `jwt-go` to `go-jose`. const ( - ValidationErrorMalformed uint32 = 1 << iota // Token is malformed - ValidationErrorUnverifiable // Token could not be verified because of signing problems - ValidationErrorSignatureInvalid // Signature validation failed. - ValidationErrorHeaderKeyIDInvalid // Header KID invalid error. - ValidationErrorHeaderAlgorithmInvalid // Header ALG invalid error. - ValidationErrorHeaderTypeInvalid // Header TYP invalid error - ValidationErrorId // Claim JTI validation failed - ValidationErrorAudience // Claim AUD validation failed - ValidationErrorExpired // Claim EXP validation failed - ValidationErrorIssuedAt // Claim IAT validation failed - ValidationErrorNotValidYet // Claim NBF validation failed - ValidationErrorIssuer // Claim ISS validation failed - ValidationErrorSubject // Claim SUB validation failed - ValidationErrorClaimsInvalid // Generic claims validation error + ValidationErrorMalformed uint32 = 1 << iota // Token is malformed + ValidationErrorUnverifiable // Token could not be verified because of signing problems + ValidationErrorSignatureInvalid // Signature validation failed. + ValidationErrorHeaderKeyIDInvalid // Header KID invalid error. + ValidationErrorHeaderAlgorithmInvalid // Header ALG invalid error. + ValidationErrorHeaderTypeInvalid // Header TYP invalid error + ValidationErrorHeaderEncryptionKeyIDInvalid // Header KID invalid error (JWE). + ValidationErrorHeaderKeyAlgorithmInvalid // Header ALG invalid error (JWE). + ValidationErrorHeaderContentEncryptionInvalid // Header ENC invalid error (JWE). + ValidationErrorId // Claim JTI validation failed + ValidationErrorAudience // Claim AUD validation failed + ValidationErrorExpired // Claim EXP validation failed + ValidationErrorIssuedAt // Claim IAT validation failed + ValidationErrorNotValidYet // Claim NBF validation failed + ValidationErrorIssuer // Claim ISS validation failed + ValidationErrorSubject // Claim SUB validation failed + ValidationErrorClaimsInvalid // Generic claims validation error ) // The ValidationError is an error implementation from Parse if token is not valid. diff --git a/token/jwt/validator.go b/token/jwt/validator.go index 0efde232..a983eec9 100644 --- a/token/jwt/validator.go +++ b/token/jwt/validator.go @@ -72,7 +72,7 @@ func ValidateIssuedAt(iat int64) ValidatorOpt { } } -func ValidateRequireIssuedAt() TokenValidationOption { +func ValidateRequireIssuedAt() HeaderValidationOption { return func(validator *Validator) { validator.requireIAT = true } From e98f4cf8e0462199df6452555b0de8e8cc7fb785 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Tue, 24 Sep 2024 06:25:07 +1000 Subject: [PATCH 11/33] client auth tests --- client_authentication_test.go | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/client_authentication_test.go b/client_authentication_test.go index 15a186a0..3fe18b4f 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -29,6 +29,7 @@ import ( func TestAuthenticateClient(t *testing.T) { keyRSA := gen.MustRSAKey() + jwksRSA := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { @@ -44,9 +45,27 @@ func TestAuthenticateClient(t *testing.T) { jwksECDSA := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - KeyID: "kid-foo", - Use: "sig", - Key: &keyECDSA.PublicKey, + KeyID: "kid-foo", + Use: "sig", + Algorithm: "ES256", + Key: &keyECDSA.PublicKey, + }, + }, + } + + jwks := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "kid-foo", + Use: "sig", + Algorithm: "RS256", + Key: keyRSA, + }, + { + KeyID: "kid-foo", + Use: "sig", + Algorithm: "ES256", + Key: keyECDSA, }, }, } @@ -703,7 +722,7 @@ func TestAuthenticateClient(t *testing.T) { config.JWTStrategy = &jwt.DefaultStrategy{ Config: config, - Issuer: jwt.NewDefaultIssuerUnverifiedFromJWKS(jwksRSA), + Issuer: jwt.NewDefaultIssuerUnverifiedFromJWKS(jwks), } provider := &Fosite{ From 5a202af5e04822e898352e35da0a3110992801bf Mon Sep 17 00:00:00 2001 From: James Elliott Date: Tue, 24 Sep 2024 22:30:30 +1000 Subject: [PATCH 12/33] client auth tests --- authorize_request_handler.go | 2 + client_authentication_strategy.go | 42 ++++-- client_authentication_test.go | 30 ++-- handler/oauth2/strategy_jwt_profile.go | 2 + token/jwt/jwt_strategy.go | 2 +- token/jwt/util.go | 2 +- token/jwt/validation_error.go | 1 + token/jwt/validator.go | 199 ------------------------- 8 files changed, 53 insertions(+), 227 deletions(-) delete mode 100644 token/jwt/validator.go diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 2d8647b9..cc0f48ca 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -597,6 +597,8 @@ func fmtRequestObjectDecodeError(token *jwt.Token, client JARClient, issuer stri return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'alg' value '%s' due to the client registration 'request_object_encryption_alg' value but the request object was encrypted with the 'alg' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionAlg(), token.KeyAlgorithm) case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'enc' value '%s' due to the client registration 'request_object_encryption_enc' value but the request object was encrypted with the 'enc' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionEnc(), token.ContentEncryption) + case errJWTValidation.Has(jwt.ValidationErrorMalformedNotCompactSerialized): + return outer.WithDebugf("%s client with id '%s' provided a request object that was malformed. The request object does not appear to be a JWE or JWS compact serialized JWT.", hintRequestObjectPrefix(openid), client.GetID()) case errJWTValidation.Has(jwt.ValidationErrorMalformed): return outer.WithDebugf("%s client with id '%s' provided a request object that was malformed. %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): diff --git a/client_authentication_strategy.go b/client_authentication_strategy.go index a96ea824..69da9544 100644 --- a/client_authentication_strategy.go +++ b/client_authentication_strategy.go @@ -10,8 +10,6 @@ import ( "strings" "time" - xjwt "github.com/golang-jwt/jwt/v5" - "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" @@ -341,18 +339,36 @@ func (s *DefaultClientAuthenticationStrategy) getClientCredentialsSecretPost(for func resolveJWTErrorToRFCError(err error) (rfc error) { var e *RFC6749Error - switch { - case errors.As(err, &e): + if errors.As(err, &e) { return errorsx.WithStack(e) - case errors.Is(err, xjwt.ErrTokenMalformed): - return errorsx.WithStack(ErrInvalidClient.WithHint("Unable to decode the 'client_assertion' value as it is malformed or incomplete.").WithWrap(err).WithDebugError(err)) - case errors.Is(err, xjwt.ErrTokenUnverifiable): - return errorsx.WithStack(ErrInvalidClient.WithHint("Unable to decode the 'client_assertion' value as it is missing the information required to validate it.").WithWrap(err).WithDebugError(err)) - case errors.Is(err, xjwt.ErrTokenNotValidYet), errors.Is(err, xjwt.ErrTokenExpired), errors.Is(err, xjwt.ErrTokenUsedBeforeIssued): - return errorsx.WithStack(ErrInvalidClient.WithHint("Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint.").WithWrap(err).WithDebugError(err)) - default: - return errorsx.WithStack(ErrInvalidClient.WithHint("Unable to decode 'client_assertion' value for an unknown reason.").WithWrap(err).WithDebugError(err)) } + + if errJWTValidation := new(jwt.ValidationError); errors.As(err, &errJWTValidation) { + switch { + case errJWTValidation.Has(jwt.ValidationErrorMalformed): + e = ErrInvalidClient. + WithHint("OAuth 2.0 client provided a client assertion which could not be decoded or validated."). + WithWrap(err). + WithDebugf("OAuth 2.0 client provided a client assertion that was malformed. %s.", strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorMalformedNotCompactSerialized): + e = ErrInvalidClient. + WithHint("OAuth 2.0 client provided a client assertion which could not be decoded or validated."). + WithWrap(err). + WithDebugf("OAuth 2.0 client provided a client assertion that was malformed. The client assertion does not appear to be a JWE or JWS compact serialized JWT.") + case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): + e = ErrInvalidClient. + WithHint("OAuth 2.0 client provided a client assertion which could not be decoded or validated."). + WithWrap(err). + WithDebugf("OAuth 2.0 client provided a client assertion that was not able to be verified. %s.", strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + default: + e = ErrInvalidClient. + WithHint("OAuth 2.0 client provided a client assertion which could not be decoded or validated."). + WithWrap(err). + WithDebugf("Unknown error occurred handling the client assertion.") + } + } + + return errorsx.WithStack(e) } func fmtClientAssertionDecodeError(token *jwt.Token, client AuthenticationMethodClient, handler EndpointClientAuthHandler, audience []string, inner error) (outer *RFC6749Error) { @@ -372,6 +388,8 @@ func fmtClientAssertionDecodeError(token *jwt.Token, client AuthenticationMethod return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'alg' value '%s' due to the client registration 'request_object_encryption_alg' value but the client assertion was encrypted with the 'alg' value '%s'.", client.GetID(), handler.GetAuthEncryptionAlg(client), token.KeyAlgorithm) case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'enc' value '%s' due to the client registration 'request_object_encryption_enc' value but the client assertion was encrypted with the 'enc' value '%s'.", client.GetID(), handler.GetAuthEncryptionEnc(client), token.ContentEncryption) + case errJWTValidation.Has(jwt.ValidationErrorMalformedNotCompactSerialized): + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was malformed. The client assertion does not appear to be a JWE or JWS compact serialized JWT.", client.GetID()) case errJWTValidation.Has(jwt.ValidationErrorMalformed): return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was malformed. %s.", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): diff --git a/client_authentication_test.go b/client_authentication_test.go index 3fe18b4f..19034572 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -12,6 +12,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "regexp" "testing" "time" @@ -59,13 +60,13 @@ func TestAuthenticateClient(t *testing.T) { KeyID: "kid-foo", Use: "sig", Algorithm: "RS256", - Key: keyRSA, + Key: &keyRSA.PublicKey, }, { KeyID: "kid-foo", Use: "sig", Algorithm: "ES256", - Key: keyECDSA, + Key: &keyECDSA.PublicKey, }, }, } @@ -80,6 +81,7 @@ func TestAuthenticateClient(t *testing.T) { r *http.Request form url.Values err string + errRegexp *regexp.Regexp expectErr error }{ { @@ -392,7 +394,7 @@ func TestAuthenticateClient(t *testing.T) { { name: "ShouldFailBecauseRSAAssertionIsUsedButECDSAAssertionIsRequired", client: func(ts *httptest.Server) Client { - return &DefaultJARClient{DefaultClient: &DefaultClient{ID: "bar", ClientSecret: testClientSecretBar}, JSONWebKeys: jwksECDSA, TokenEndpointAuthMethod: "private_key_jwt", TokenEndpointAuthSigningAlg: "ES256"} + return &DefaultJARClient{DefaultClient: &DefaultClient{ID: "bar", ClientSecret: testClientSecretBar}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt", TokenEndpointAuthSigningAlg: "ES256"} }, form: url.Values{"client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ consts.ClaimSubject: "bar", consts.ClaimExpirationTime: time.Now().Add(time.Hour).Unix(), @@ -402,7 +404,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The requested OAuth 2.0 client does not support the 'token_endpoint_auth_signing_alg' value 'RS256'. The registered OAuth 2.0 client with id 'bar' only supports the 'ES256' algorithm.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' expects client assertions to be signed with the 'alg' value 'ES256' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' value 'RS256'.", }, { name: "ShouldFailBecauseMalformedAssertionUsed", @@ -411,7 +413,7 @@ func TestAuthenticateClient(t *testing.T) { }, form: url.Values{"client_assertion": []string{"bad.assertion"}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to decode the 'client_assertion' value as it is malformed or incomplete. token is malformed: token contains an invalid number of segments", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client provided a client assertion which could not be decoded or validated. OAuth 2.0 client provided a client assertion that was malformed. The client assertion does not appear to be a JWE or JWS compact serialized JWT.", }, { name: "ShouldFailBecauseExpired", @@ -426,7 +428,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyECDSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint. token has invalid claims: token is expired", + errRegexp: regexp.MustCompile(`^Client authentication failed \(e\.g\., unknown client, no client authentication included, or unsupported authentication method\)\. OAuth 2\.0 client with id 'bar' provided a client assertion which could not be decoded or validated\. OAuth 2\.0 client with id 'bar' provided a client assertion that was expired\. The client assertion expired at \d+\.$`), }, { name: "ShouldFailBecauseNotBefore", @@ -442,7 +444,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyECDSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint. token has invalid claims: token is not valid yet", + errRegexp: regexp.MustCompile(`^Client authentication failed \(e\.g\., unknown client, no client authentication included, or unsupported authentication method\)\. OAuth 2\.0 client with id 'bar' provided a client assertion which could not be decoded or validated\. OAuth 2\.0 client with id 'bar' provided a client assertion that was issued in the future\. The client assertion is not valid before \d+\.$`), }, { name: "ShouldFailBecauseIssuedInFuture", @@ -458,7 +460,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyECDSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint. token has invalid claims: token used before issued", + errRegexp: regexp.MustCompile(`^Client authentication failed \(e\.g\., unknown client, no client authentication included, or unsupported authentication method\)\. OAuth 2\.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2\.0 client with id 'bar' provided a client assertion that was issued in the future\. The client assertion was issued at \d+\.$`), }, { name: "ShouldFailBecauseNoKeys", @@ -473,7 +475,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyECDSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' provided a client assertion that was not able to be verified. Error occurred retrieving the JSON Web Key. No JWKs have been registered for the client.", }, { name: "ShouldFailBecauseNotBefore", @@ -747,11 +749,7 @@ func TestAuthenticateClient(t *testing.T) { c, _, err := provider.AuthenticateClient(context.Background(), tc.r, tc.form) - if len(tc.err) != 0 { - require.EqualError(t, ErrorToDebugRFC6749Error(err), tc.err) - } - - if len(tc.err) == 0 && tc.expectErr == nil { + if len(tc.err) == 0 && tc.expectErr == nil && tc.errRegexp == nil { require.NoError(t, ErrorToDebugRFC6749Error(err)) assert.EqualValues(t, client, c) } else { @@ -762,6 +760,10 @@ func TestAuthenticateClient(t *testing.T) { if tc.expectErr != nil { assert.EqualError(t, err, tc.expectErr.Error()) } + + if tc.errRegexp != nil { + require.Regexp(t, tc.errRegexp, ErrorToDebugRFC6749Error(err).Error()) + } } }) } diff --git a/handler/oauth2/strategy_jwt_profile.go b/handler/oauth2/strategy_jwt_profile.go index d0703554..eb3f9c31 100644 --- a/handler/oauth2/strategy_jwt_profile.go +++ b/handler/oauth2/strategy_jwt_profile.go @@ -213,6 +213,8 @@ func fmtValidateJWTError(token *jwt.Token, client jwt.Client, inner error) (err return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'alg' value '%s' but it was encrypted with the 'alg' value '%s'.", clientText, encAlg, token.KeyAlgorithm) case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'enc' value '%s' but it was encrypted with the 'enc' value '%s'.", clientText, enc, token.ContentEncryption) + case errJWTValidation.Has(jwt.ValidationErrorMalformedNotCompactSerialized): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis malformed. The token does not appear to be a JWE or JWS compact serialized JWT.", clientText) case errJWTValidation.Has(jwt.ValidationErrorMalformed): return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis malformed. %s.", clientText, strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index efa8657d..f2ec3ab7 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -103,7 +103,7 @@ func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, op if IsSignedJWT(tokenStringEnc) { return tokenStringEnc, "", nil, nil } else { - return tokenStringEnc, "", nil, fmt.Errorf("Provided value does not appear to be a JWE or JWS compact serialized JWT") + return tokenStringEnc, "", nil, errorsx.WithStack(&ValidationError{text: "Provided value does not appear to be a JWE or JWS compact serialized JWT", Errors: ValidationErrorMalformedNotCompactSerialized}) } } diff --git a/token/jwt/util.go b/token/jwt/util.go index 54bf3334..040fb6ef 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -225,7 +225,7 @@ func FindClientPublicJWK(ctx context.Context, client BaseClient, fetcher JWKSFet return SearchJWKS(keys, kid, alg, use, strict) } - return nil, &JWKLookupError{Description: "No JWKs have been registered for the client."} + return nil, &JWKLookupError{Description: "No JWKs have been registered for the client"} } func SearchJWKS(jwks *jose.JSONWebKeySet, kid, alg, use string, strict bool) (key *jose.JSONWebKey, err error) { diff --git a/token/jwt/validation_error.go b/token/jwt/validation_error.go index 1e53cf03..085102fb 100644 --- a/token/jwt/validation_error.go +++ b/token/jwt/validation_error.go @@ -7,6 +7,7 @@ package jwt // from `jwt-go` to `go-jose`. const ( ValidationErrorMalformed uint32 = 1 << iota // Token is malformed + ValidationErrorMalformedNotCompactSerialized // Token is malformed specifically it does not have the compact serialized format. ValidationErrorUnverifiable // Token could not be verified because of signing problems ValidationErrorSignatureInvalid // Signature validation failed. ValidationErrorHeaderKeyIDInvalid // Header KID invalid error. diff --git a/token/jwt/validator.go b/token/jwt/validator.go deleted file mode 100644 index a983eec9..00000000 --- a/token/jwt/validator.go +++ /dev/null @@ -1,199 +0,0 @@ -package jwt - -/* -func NewValidator(opts ...ValidatorOpt) (validator *Validator) { - validator = &Validator{ - types: []string{consts.JSONWebTokenTypeJWT}, - nbf: -1, - exp: -1, - iat: -1, - } - - for _, opt := range opts { - opt(validator) - } - - return validator -} - -type ValidatorOpt func(*Validator) - -func ValidateIssuer(iss string) ValidatorOpt { - return func(validator *Validator) { - validator.iss = iss - } -} - -func ValidateSubject(sub string) ValidatorOpt { - return func(validator *Validator) { - validator.sub = sub - } -} - -func ValidateAudienceAll(aud []string) ValidatorOpt { - return func(validator *Validator) { - validator.audAll = aud - } -} - -func ValidateAudienceAny(aud []string) ValidatorOpt { - return func(validator *Validator) { - validator.audAny = aud - } -} - -func ValidateNotBefore(nbf int64) ValidatorOpt { - return func(validator *Validator) { - validator.nbf = nbf - } -} - -func ValidateRequireNotBefore() ValidatorOpt { - return func(validator *Validator) { - validator.requireNBF = true - } -} - -func ValidateExpires(exp int64) ValidatorOpt { - return func(validator *Validator) { - validator.exp = exp - } -} - -func ValidateRequireExpires() ValidatorOpt { - return func(validator *Validator) { - validator.requireEXP = true - } -} - -func ValidateIssuedAt(iat int64) ValidatorOpt { - return func(validator *Validator) { - validator.iat = iat - } -} - -func ValidateRequireIssuedAt() HeaderValidationOption { - return func(validator *Validator) { - validator.requireIAT = true - } -} - -*/ - -/* -type Validator struct { - types []string - alg string - kid string - iss string - sub string - audAll []string - audAny []string - nbf int64 - requireNBF bool - exp int64 - requireEXP bool - iat int64 - requireIAT bool -} - -func (v Validator) Validate(token *Token) (err error) { - vErr := new(ValidationError) - now := TimeFunc().Unix() - - if len(v.types) != 0 { - if !validateTokenType(v.types, token.Header) { - vErr.Inner = errors.New("token has an invalid typ") - vErr.Errors |= ValidationErrorHeaderTypeInvalid - } - } - - if len(v.alg) != 0 { - if v.alg != string(token.SignatureAlgorithm) { - vErr.Inner = errors.New("token has an invalid alg") - vErr.Errors |= ValidationErrorHeaderAlgorithmInvalid - } - } - - if len(v.kid) != 0 { - if v.kid != token.KeyID { - vErr.Inner = errors.New("token has an invalid kid") - vErr.Errors |= ValidationErrorHeaderKeyIDInvalid - } - } - - if len(v.iss) != 0 { - if !token.Claims.VerifyIssuer(v.iss, true) { - vErr.Inner = errors.New("token has an invalid issuer") - vErr.Errors |= ValidationErrorIssuer - } - } - - if len(v.sub) != 0 { - if !token.Claims.VerifySubject(v.sub, true) { - vErr.Inner = errors.New("token has an invalid subject") - vErr.Errors |= ValidationErrorSubject - } - } - - if len(v.audAll) != 0 { - if !token.Claims.VerifyAudienceAll(v.audAll, true) { - vErr.Inner = errors.New("token has an invalid audience") - vErr.Errors |= ValidationErrorAudience - } - } - - if len(v.audAny) != 0 { - if !token.Claims.VerifyAudienceAny(v.audAny, true) { - vErr.Inner = errors.New("token has an invalid audience") - vErr.Errors |= ValidationErrorAudience - } - } - - if v.exp != -1 { - exp := v.exp - - if exp == 0 { - exp = now - } - - if !token.Claims.VerifyExpiresAt(exp, v.requireEXP) { - vErr.Inner = errors.New("token is expired") - vErr.Errors |= ValidationErrorExpired - } - } - - if v.iat != -1 { - iat := v.iat - - if iat == 0 { - iat = now - } - - if !token.Claims.VerifyIssuedAt(iat, v.requireIAT) { - vErr.Inner = errors.New("token used before issued") - vErr.Errors |= ValidationErrorIssuedAt - } - } - - if v.nbf != -1 { - nbf := v.nbf - - if nbf == 0 { - nbf = now - } - - if !token.Claims.VerifyNotBefore(nbf, v.requireNBF) { - vErr.Inner = errors.New("token is not valid yet") - vErr.Errors |= ValidationErrorNotValidYet - } - } - - if vErr.valid() { - return nil - } - - return vErr -} - -*/ From 16ba9fd52df265824bfd2cd6b13c3eca99c20ec2 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Wed, 25 Sep 2024 19:57:35 +1000 Subject: [PATCH 13/33] client auth tests --- ...orize_request_handler_oidc_request_test.go | 2 +- client_authentication.go | 29 -------- client_authentication_strategy.go | 69 +++++++++++-------- client_authentication_test.go | 18 ++--- token/jwt/jwt_strategy.go | 2 +- token/jwt/util.go | 16 +++-- 6 files changed, 63 insertions(+), 73 deletions(-) diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index 5402e7de..6f52afcf 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -187,7 +187,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' provided a request object that could not be validated. Provided value does not appear to be a JWE or JWS compact serialized JWT.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' provided a request object that was malformed. The request object does not appear to be a JWE or JWS compact serialized JWT.", }, { name: "ShouldFailUnknownKID", diff --git a/client_authentication.go b/client_authentication.go index b7610d3a..5a4eddd2 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -111,35 +111,6 @@ func CompareClientSecret(ctx context.Context, client Client, rawSecret []byte) ( return err } -// FindClientPublicJWK takes a JARClient and a kid, alg, and use to resolve a Public JWK for the client. -func FindClientPublicJWK(ctx context.Context, provider JWKSFetcherStrategyProvider, client JSONWebKeysClient, kid, alg, use string) (key any, err error) { - if set := client.GetJSONWebKeys(); set != nil { - return findPublicKeyByKID(kid, alg, use, set) - } - - strategy := provider.GetJWKSFetcherStrategy(ctx) - - var keys *jose.JSONWebKeySet - - if location := client.GetJSONWebKeysURI(); len(location) > 0 { - if keys, err = strategy.Resolve(ctx, location, false); err != nil { - return nil, err - } - - if key, err = findPublicKeyByKID(kid, alg, use, keys); err == nil { - return key, nil - } - - if keys, err = strategy.Resolve(ctx, location, true); err != nil { - return nil, err - } - - return findPublicKeyByKID(kid, alg, use, keys) - } - - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request.")) -} - func getClientCredentialsSecretBasic(r *http.Request) (id, secret string, ok bool, err error) { auth := r.Header.Get(consts.HeaderAuthorization) diff --git a/client_authentication_strategy.go b/client_authentication_strategy.go index 69da9544..e59b9032 100644 --- a/client_authentication_strategy.go +++ b/client_authentication_strategy.go @@ -139,8 +139,8 @@ func NewClientAssertion(ctx context.Context, strategy jwt.Strategy, store Client var ( token *jwt.Token - id, alg, method string - client Client + id, method string + client Client ) switch assertionType { @@ -152,7 +152,7 @@ func NewClientAssertion(ctx context.Context, strategy jwt.Strategy, store Client return &ClientAssertion{Assertion: assertion, Type: assertionType}, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unknown client_assertion_type '%s'.", assertionType)) } - if token, err = strategy.Decode(ctx, assertion, jwt.WithAllowUnverified()); err != nil { + if token, err = strategy.Decode(ctx, assertion, jwt.WithAllowUnverified(), jwt.WithSigAlgorithm(jwt.SignatureAlgorithmsNone...)); err != nil { return &ClientAssertion{Assertion: assertion, Type: assertionType}, resolveJWTErrorToRFCError(err) } @@ -168,10 +168,10 @@ func NewClientAssertion(ctx context.Context, strategy jwt.Strategy, store Client return &ClientAssertion{Assertion: assertion, Type: assertionType, ID: id}, nil } - var c AuthenticationMethodClient + method = consts.ClientAuthMethodPrivateKeyJWT - if c, ok = client.(AuthenticationMethodClient); ok { - alg, method = handler.GetAuthSigningAlg(c), handler.GetAuthMethod(c) + if jwt.IsSignedJWTClientSecretAlg(token.SignatureAlgorithm) { + method = consts.ClientAuthMethodClientSecretJWT } return &ClientAssertion{ @@ -180,7 +180,7 @@ func NewClientAssertion(ctx context.Context, strategy jwt.Strategy, store Client Parsed: true, ID: id, Method: method, - Algorithm: alg, + Algorithm: string(token.SignatureAlgorithm), Client: client, }, nil } @@ -228,7 +228,7 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateClientSecret(ctx con return "", errorsx.WithStack( ErrInvalidClient. WithHintf("The request was determined to be using '%s_endpoint_auth_method' method '%s', however the OAuth 2.0 client registration does not allow this method.", handler.Name(), method). - WithDebugf("The registered client with id '%s' is configured to only support '%s_endpoint_auth_method' method '%s'. Either the Authorization Server client registration will need to have the '%s_endpoint_auth_method' updated to '%s' or the Relying Party will need to be configured to use '%s'.", client.GetID(), handler.Name(), cmethod, handler.Name(), method, cmethod)) + WithDebugf("The registered client with id '%s' is configured to only support '%s_endpoint_auth_method' method '%s'. Either the Authorization Server client registration will need to have the '%s_endpoint_auth_method' updated to '%s'x or the Relying Party will need to be configured to use '%s'.", client.GetID(), handler.Name(), cmethod, handler.Name(), method, cmethod)) } } @@ -249,9 +249,19 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateClientSecret(ctx con func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(ctx context.Context, client Client, assertion *ClientAssertion, handler EndpointClientAuthHandler) (method string, err error) { var ( token *jwt.Token + c AuthenticationMethodClient + ok bool ) - if method, _, _, token, err = s.doAuthenticateAssertionParseAssertionJWTBearer(ctx, client, assertion, handler); err != nil { + if c, ok = client.(AuthenticationMethodClient); !ok { + return "", errorsx.WithStack(ErrInvalidRequest.WithHint("The registered client does not support OAuth 2.0 JWT Profile Client Authentication RFC7523 or OpenID Connect 1.0 specific authentication methods.")) + } + + if !assertion.Parsed { + + } + + if method, _, _, token, err = s.doAuthenticateAssertionParseAssertionJWTBearer(ctx, c, assertion, handler); err != nil { return "", err } @@ -273,6 +283,16 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(c case claims.JTI == "": return "", errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'jti' from 'client_assertion' must be set but is not.")) default: + switch cmethod := handler.GetAuthMethod(c); { + case cmethod == "" && handler.AllowAuthMethodAny(): + break + case cmethod != method: + return "", errorsx.WithStack( + ErrInvalidClient. + WithHintf("The request was determined to be using '%s_endpoint_auth_method' method '%s', however the OAuth 2.0 client registration does not allow this method.", handler.Name(), method). + WithDebugf("The registered client with id '%s' is configured to only support '%s_endpoint_auth_method' method '%s'. Either the Authorization Server client registration will need to have the '%s_endpoint_auth_method' updated to '%s' or the Relying Party will need to be configured to use '%s'.", client.GetID(), handler.Name(), cmethod, handler.Name(), method, cmethod)) + } + if err = s.Store.ClientAssertionJWTValid(ctx, claims.JTI); err != nil { return "", errorsx.WithStack(ErrJTIKnown.WithHint("Claim 'jti' from 'client_assertion' MUST only be used once.").WithDebugError(err)) } @@ -285,24 +305,15 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(c } } -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearer(ctx context.Context, client Client, assertion *ClientAssertion, handler EndpointClientAuthHandler) (method, kid, alg string, token *jwt.Token, err error) { +func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearer(ctx context.Context, client AuthenticationMethodClient, assertion *ClientAssertion, handler EndpointClientAuthHandler) (method, kid, alg string, token *jwt.Token, err error) { audience := s.Config.GetAllowedJWTAssertionAudiences(ctx) if len(audience) == 0 { return "", "", "", nil, errorsx.WithStack(ErrMisconfiguration.WithHint("The authorization server does not support OAuth 2.0 JWT Profile Client Authentication RFC7523 or OpenID Connect 1.0 specific authentication methods.").WithDebug("The authorization server could not determine any safe value for it's audience but it's required to validate the RFC7523 client assertions.")) } - var ( - c AuthenticationMethodClient - ok bool - ) - - if c, ok = client.(AuthenticationMethodClient); !ok { - return "", "", "", nil, errorsx.WithStack(ErrInvalidRequest.WithHint("The registered client does not support OAuth 2.0 JWT Profile Client Authentication RFC7523 or OpenID Connect 1.0 specific authentication methods.")) - } - - if token, err = s.Config.GetJWTStrategy(ctx).Decode(ctx, assertion.Assertion, jwt.WithClient(&EndpointClientAuthJWTClient{client: c, handler: handler})); err != nil { - return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, c, handler, audience, err)) + if token, err = s.Config.GetJWTStrategy(ctx).Decode(ctx, assertion.Assertion, jwt.WithClient(&EndpointClientAuthJWTClient{client: client, handler: handler}), jwt.WithSigAlgorithm(jwt.SignatureAlgorithmsNone...)); err != nil { + return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, client, handler, audience, err)) } optsClaims := []jwt.ClaimValidationOption{ @@ -312,22 +323,22 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssert } if err = token.Claims.Valid(optsClaims...); err != nil { - return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, c, handler, audience, err)) + return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, client, handler, audience, err)) } optsHeader := []jwt.HeaderValidationOption{ - jwt.ValidateKeyID(handler.GetAuthSigningKeyID(c)), - jwt.ValidateAlgorithm(handler.GetAuthSigningAlg(c)), - jwt.ValidateEncryptionKeyID(handler.GetAuthEncryptionKeyID(c)), - jwt.ValidateKeyAlgorithm(handler.GetAuthEncryptionAlg(c)), - jwt.ValidateContentEncryption(handler.GetAuthEncryptionEnc(c)), + jwt.ValidateKeyID(handler.GetAuthSigningKeyID(client)), + jwt.ValidateAlgorithm(handler.GetAuthSigningAlg(client)), + jwt.ValidateEncryptionKeyID(handler.GetAuthEncryptionKeyID(client)), + jwt.ValidateKeyAlgorithm(handler.GetAuthEncryptionAlg(client)), + jwt.ValidateContentEncryption(handler.GetAuthEncryptionEnc(client)), } if err = token.Valid(optsHeader...); err != nil { - return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, c, handler, audience, err)) + return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, client, handler, audience, err)) } - return method, kid, alg, token, nil + return assertion.Method, kid, alg, token, nil } func (s *DefaultClientAuthenticationStrategy) getClientCredentialsSecretPost(form url.Values) (id, secret string, ok bool) { diff --git a/client_authentication_test.go b/client_authentication_test.go index 19034572..7167940d 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -478,7 +478,7 @@ func TestAuthenticateClient(t *testing.T) { err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' provided a client assertion that was not able to be verified. Error occurred retrieving the JSON Web Key. No JWKs have been registered for the client.", }, { - name: "ShouldFailBecauseNotBefore", + name: "ShouldFailBecauseNotBeforeAlternative", client: func(ts *httptest.Server) Client { return &DefaultJARClient{DefaultClient: &DefaultClient{ID: "bar", ClientSecret: testClientSecretBar}, JSONWebKeys: jwksECDSA, TokenEndpointAuthMethod: "private_key_jwt", TokenEndpointAuthSigningAlg: "ES256"} }, form: url.Values{"client_assertion": {mustGenerateECDSAAssertion(t, jwt.MapClaims{ @@ -491,7 +491,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyECDSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint. token has invalid claims: token is not valid yet", + errRegexp: regexp.MustCompile(`^Client authentication failed \(e\.g\., unknown client, no client authentication included, or unsupported authentication method\)\. OAuth 2\.0 client with id 'bar' provided a client assertion which could not be decoded or validated\. OAuth 2\.0 client with id 'bar' provided a client assertion that was issued in the future\. The client assertion is not valid before \d+\.$`), }, { name: "ShouldFailBecauseTokenAuthMethodIsNotPrivateKeyJwtButClientSecretJwt", @@ -506,7 +506,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The requested OAuth 2.0 client does not support the 'token_endpoint_auth_signing_alg' value 'RS256'.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The request was determined to be using 'token_endpoint_auth_method' method 'private_key_jwt', however the OAuth 2.0 client registration does not allow this method. The registered client with id 'bar' is configured to only support 'token_endpoint_auth_method' method 'client_secret_jwt'. Either the Authorization Server client registration will need to have the 'token_endpoint_auth_method' updated to 'private_key_jwt' or the Relying Party will need to be configured to use 'client_secret_jwt'.", }, { name: "ShouldFailBecauseTokenAuthMethodIsNotPrivateKeyJwtButNone", @@ -521,7 +521,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). This requested OAuth 2.0 client does not support client authentication, however 'client_assertion' was provided in the request.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The request was determined to be using 'token_endpoint_auth_method' method 'private_key_jwt', however the OAuth 2.0 client registration does not allow this method. The registered client with id 'bar' is configured to only support 'token_endpoint_auth_method' method 'none'. Either the Authorization Server client registration will need to have the 'token_endpoint_auth_method' updated to 'private_key_jwt' or the Relying Party will need to be configured to use 'none'.", }, { name: "ShouldFailBecauseTokenAuthMethodIsNotPrivateKeyJwtButClientSecretPost", @@ -536,7 +536,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). This requested OAuth 2.0 client only supports client authentication method 'client_secret_post', however 'client_assertion' was provided in the request.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The request was determined to be using 'token_endpoint_auth_method' method 'private_key_jwt', however the OAuth 2.0 client registration does not allow this method. The registered client with id 'bar' is configured to only support 'token_endpoint_auth_method' method 'client_secret_post'. Either the Authorization Server client registration will need to have the 'token_endpoint_auth_method' updated to 'private_key_jwt' or the Relying Party will need to be configured to use 'client_secret_post'.", }, { name: "ShouldFailBecauseTokenAuthMethodIsNotPrivateKeyJwtButClientSecretBasic", @@ -551,7 +551,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). This requested OAuth 2.0 client only supports client authentication method 'client_secret_basic', however 'client_assertion' was provided in the request.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The request was determined to be using 'token_endpoint_auth_method' method 'private_key_jwt', however the OAuth 2.0 client registration does not allow this method. The registered client with id 'bar' is configured to only support 'token_endpoint_auth_method' method 'client_secret_basic'. Either the Authorization Server client registration will need to have the 'token_endpoint_auth_method' updated to 'private_key_jwt' or the Relying Party will need to be configured to use 'client_secret_basic'.", }, { name: "ShouldFailBecauseTokenAuthMethodIsNotPrivateKeyJwtButFoobar", @@ -593,7 +593,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint. Unable to validate the 'aud' claim of the 'client_assertion' value 'token-url-1', 'token-url-2' as it doesn't match any of the expected values 'token-url'.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' provided a client assertion that has an invalid audience. The client assertion was expected to have an 'aud' claim which matches one of the values 'token-url' but the 'aud' claim had the values 'token-url-1', 'token-url-2'.", }, { name: "ShouldPassWithProperAssertionWhenJWKsAreSetWithinTheClient", @@ -635,7 +635,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The requested OAuth 2.0 client does not support the 'token_endpoint_auth_signing_alg' value 'none'. The registered OAuth 2.0 client with id 'bar' only supports the 'RS256' algorithm.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' expects client assertions to be signed with the 'alg' value 'RS256' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' value 'none'.", }, { name: "ShouldPassWithProperAssertionWhenJWKsURIIsSet", @@ -663,7 +663,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The supplied 'client_id' did not match the 'sub' claim of the 'client_assertion'.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.", }, { name: "ShouldFailBecauseClientAssertionIssDoesNotMatchClient", diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index f2ec3ab7..e45a2e85 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -291,7 +291,7 @@ func (j *DefaultStrategy) validate(ctx context.Context, t *jwt.JSONWebToken, des return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } } else if o.client != nil && o.client.IsClientSigned() { - if IsSignedJWTClientSecretAlg(alg) { + if IsSignedJWTClientSecretAlgStr(alg) { if kid != "" { return errorsx.WithStack(&ValidationError{Errors: ValidationErrorHeaderKeyIDInvalid, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not support tokens with a kid but the token has kid '%s'", alg, kid)}) } diff --git a/token/jwt/util.go b/token/jwt/util.go index 040fb6ef..da7c7714 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -28,9 +28,17 @@ func IsEncryptedJWT(tokenString string) (encrypted bool) { return reEncryptedJWT.MatchString(tokenString) } -// IsSignedJWTClientSecretAlg returns true if the given alg string is a client secret based signature algorithm. -func IsSignedJWTClientSecretAlg(alg string) (csa bool) { - switch a := jose.SignatureAlgorithm(alg); a { +// IsSignedJWTClientSecretAlgStr returns true if the given alg string is a client secret based signature algorithm. +func IsSignedJWTClientSecretAlgStr(alg string) (csa bool) { + if a := jose.SignatureAlgorithm(alg); IsSignedJWTClientSecretAlg(a) { + return true + } + + return false +} + +func IsSignedJWTClientSecretAlg(alg jose.SignatureAlgorithm) (csa bool) { + switch alg { case jose.HS256, jose.HS384, jose.HS512: return true default: @@ -343,7 +351,7 @@ func getPublicJWK(jwk *jose.JSONWebKey) jose.JSONWebKey { return jose.JSONWebKey{} } - if _, ok := jwk.Key.([]byte); ok && IsSignedJWTClientSecretAlg(jwk.Algorithm) { + if _, ok := jwk.Key.([]byte); ok && IsSignedJWTClientSecretAlgStr(jwk.Algorithm) { return jose.JSONWebKey{ KeyID: jwk.KeyID, Key: jwk.Key, From b4a19f7d48438f07e75da809822fb8a2142fa7f2 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Wed, 25 Sep 2024 20:26:57 +1000 Subject: [PATCH 14/33] client auth tests --- client_authentication_test.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/client_authentication_test.go b/client_authentication_test.go index 7167940d..48b2c507 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -779,9 +779,10 @@ func TestAuthenticateClientTwice(t *testing.T) { JSONWebKeys: &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - KeyID: "kid-foo", - Use: consts.JSONWebTokenUseSignature, - Key: &key.PublicKey, + KeyID: "kid-foo", + Use: consts.JSONWebTokenUseSignature, + Algorithm: "RS256", + Key: &key.PublicKey, }, }, }, @@ -814,13 +815,14 @@ func TestAuthenticateClientTwice(t *testing.T) { }, key, "kid-foo")}, consts.FormParameterClientAssertionType: []string{consts.ClientAssertionTypeJWTBearer}} c, _, err := provider.AuthenticateClient(context.TODO(), new(http.Request), formValues) - require.NoError(t, err, "%#v", err) + require.NoError(t, ErrorToDebugRFC6749Error(err)) assert.Equal(t, client, c) // replay the request and expect it to fail c, _, err = provider.AuthenticateClient(context.TODO(), new(http.Request), formValues) require.Error(t, err) assert.EqualError(t, err, ErrJTIKnown.Error()) + assert.EqualError(t, ErrorToDebugRFC6749Error(err), "The jti was already used. Claim 'jti' from 'client_assertion' MUST only be used once. The jti was already used.") assert.Nil(t, c) } From fbc22dbb3326558e1346e2281ba87e3a6958aac9 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Wed, 25 Sep 2024 21:33:49 +1000 Subject: [PATCH 15/33] client auth tests --- token/jwt/errors.go | 7 ------- token/jwt/jwt_strategy.go | 4 ++-- 2 files changed, 2 insertions(+), 9 deletions(-) delete mode 100644 token/jwt/errors.go diff --git a/token/jwt/errors.go b/token/jwt/errors.go deleted file mode 100644 index 6607be86..00000000 --- a/token/jwt/errors.go +++ /dev/null @@ -1,7 +0,0 @@ -package jwt - -import "errors" - -var ( - ErrNotRegistered = errors.New("error: no JWKS registered") -) diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index e45a2e85..f068e474 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -108,7 +108,7 @@ func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, op } o := &StrategyOpts{ - sigAlgorithm: SignatureAlgorithmsNone, + sigAlgorithm: SignatureAlgorithms, keyAlgorithm: EncryptionKeyAlgorithms, contentEncryption: ContentEncryptionAlgorithms, } @@ -161,7 +161,7 @@ func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, op var t *jwt.JSONWebToken - if t, err = jwt.ParseSigned(tokenString, SignatureAlgorithmsNone); err != nil { + if t, err = jwt.ParseSigned(tokenString, o.sigAlgorithm); err != nil { return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } From 2d195d15b1551b6a4f3ec2530cfe2a41f18098db Mon Sep 17 00:00:00 2001 From: James Elliott Date: Wed, 25 Sep 2024 21:49:54 +1000 Subject: [PATCH 16/33] client auth tests --- client_authentication.go | 145 ------------------------------ client_authentication_strategy.go | 10 ++- client_authentication_test.go | 8 +- 3 files changed, 11 insertions(+), 152 deletions(-) diff --git a/client_authentication.go b/client_authentication.go index 5a4eddd2..4bc22158 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -6,19 +6,15 @@ package oauth2 import ( "context" "crypto" - "crypto/ecdsa" - "crypto/rsa" "encoding/base64" "errors" "net/http" "net/url" - "sort" "strings" "github.com/go-jose/go-jose/v4" "authelia.com/provider/oauth2/internal/consts" - "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -47,34 +43,6 @@ func (f *Fosite) AuthenticateClientWithAuthHandler(ctx context.Context, r *http. return strategy.AuthenticateClient(ctx, r, form, handler) } -func (f *Fosite) findClientPublicJWK(ctx context.Context, client JARClient, t *jwt.Token, expectsRSAKey bool) (key any, err error) { - var ( - keys *jose.JSONWebKeySet - ) - - if keys = client.GetJSONWebKeys(); keys != nil { - return findPublicKey(t, keys, expectsRSAKey) - } - - if location := client.GetJSONWebKeysURI(); len(location) > 0 { - if keys, err = f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, false); err != nil { - return nil, err - } - - if key, err = findPublicKey(t, keys, expectsRSAKey); err == nil { - return key, nil - } - - if keys, err = f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, true); err != nil { - return nil, err - } - - return findPublicKey(t, keys, expectsRSAKey) - } - - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request.")) -} - // CompareClientSecret compares a raw secret input from a client to the registered client secret. If the secret is valid // it returns nil, otherwise it returns an error. The ErrClientSecretNotRegistered error indicates the ClientSecret // is nil, all other errors are returned directly from the ClientSecret.Compare function. @@ -159,119 +127,6 @@ func getClientCredentialsSecretBasic(r *http.Request) (id, secret string, ok boo return id, secret, secret != "", nil } -func getJWTHeaderKIDAlg(header map[string]any) (kid, alg string) { - kid, _ = header[consts.JSONWebTokenHeaderKeyIdentifier].(string) - alg, _ = header[consts.JSONWebTokenHeaderAlgorithm].(string) - - return kid, alg -} - -type partial struct { - points int - jwk jose.JSONWebKey -} - -func findPublicKeyByKID(kid, alg, use string, set *jose.JSONWebKeySet) (key any, err error) { - if len(set.Keys) == 0 { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The retrieved JSON Web Key Set does not contain any JSON Web Keys.")) - } - - partials := []partial{} - - for _, jwk := range set.Keys { - if jwk.Use == use && jwk.Algorithm == alg && jwk.KeyID == kid { - switch k := jwk.Key.(type) { - case PrivateKey: - return k.Public(), nil - default: - return k, nil - } - } - - p := partial{} - - if jwk.KeyID != kid { - if jwk.KeyID == "" { - p.points -= 3 - } else { - continue - } - } - - if jwk.Use != use { - if jwk.Use == "" { - p.points -= 2 - } else { - continue - } - } - - if jwk.Algorithm != alg && jwk.Algorithm != "" { - if jwk.Algorithm == "" { - p.points -= 1 - } else { - continue - } - } - - p.jwk = jwk - - partials = append(partials, p) - } - - if len(partials) != 0 { - sort.Slice(partials, func(i, j int) bool { - return partials[i].points > partials[j].points - }) - - switch k := partials[0].jwk.Key.(type) { - case PrivateKey: - return k.Public(), nil - default: - return k, nil - } - } - - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find JWK with kid value '%s', alg value '%s', and use value '%s' in the JSON Web Key Set.", kid, alg, use)) -} - -func findPublicKey(t *jwt.Token, set *jose.JSONWebKeySet, expectsRSAKey bool) (any, error) { - keys := set.Keys - if len(keys) == 0 { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The retrieved JSON Web Key Set does not contain any key.")) - } - - kid, ok := t.Header[consts.JSONWebTokenHeaderKeyIdentifier].(string) - if ok { - keys = set.Key(kid) - } - - if len(keys) == 0 { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The JSON Web Token uses signing key with kid '%s', which could not be found.", kid)) - } - - for _, key := range keys { - if key.Use != consts.JSONWebTokenUseSignature { - continue - } - if expectsRSAKey { - if k, ok := key.Key.(*rsa.PublicKey); ok { - return k, nil - } - } else { - if k, ok := key.Key.(*ecdsa.PublicKey); ok { - return k, nil - } - } - } - - if expectsRSAKey { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find RSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) - } else { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find ECDSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) - } -} - func getClientCredentialsClientAssertion(form url.Values) (assertion, assertionType string, hasAssertion bool) { assertionType, assertion = form.Get(consts.FormParameterClientAssertionType), form.Get(consts.FormParameterClientAssertion) diff --git a/client_authentication_strategy.go b/client_authentication_strategy.go index e59b9032..6baa63c2 100644 --- a/client_authentication_strategy.go +++ b/client_authentication_strategy.go @@ -277,11 +277,11 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(c switch { case subtle.ConstantTimeCompare([]byte(claims.Issuer), clientID) == 0: - return "", errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) + return "", errorsx.WithStack(ErrInvalidClient.WithHint("The client assertion had invalid claims.").WithDebug("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) case subtle.ConstantTimeCompare([]byte(claims.Subject), clientID) == 0: - return "", errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) + return "", errorsx.WithStack(ErrInvalidClient.WithHint("The client assertion had invalid claims.").WithDebug("Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) case claims.JTI == "": - return "", errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'jti' from 'client_assertion' must be set but is not.")) + return "", errorsx.WithStack(ErrInvalidClient.WithHint("The client assertion had invalid claims.").WithDebug("Claim 'jti' from 'client_assertion' must be set but is not.")) default: switch cmethod := handler.GetAuthMethod(c); { case cmethod == "" && handler.AllowAuthMethodAny(): @@ -293,6 +293,10 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(c WithDebugf("The registered client with id '%s' is configured to only support '%s_endpoint_auth_method' method '%s'. Either the Authorization Server client registration will need to have the '%s_endpoint_auth_method' updated to '%s' or the Relying Party will need to be configured to use '%s'.", client.GetID(), handler.Name(), cmethod, handler.Name(), method, cmethod)) } + if !assertion.Parsed { + return "", errorsx.WithStack(ErrInvalidClient.WithDebug("The client assertion was not able to be parsed.")) + } + if err = s.Store.ClientAssertionJWTValid(ctx, claims.JTI); err != nil { return "", errorsx.WithStack(ErrJTIKnown.WithHint("Claim 'jti' from 'client_assertion' MUST only be used once.").WithDebugError(err)) } diff --git a/client_authentication_test.go b/client_authentication_test.go index 48b2c507..3ad21753 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -663,7 +663,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The client assertion had invalid claims. Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.", }, { name: "ShouldFailBecauseClientAssertionIssDoesNotMatchClient", @@ -678,10 +678,10 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The client assertion had invalid claims. Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.", }, { - name: "ShouldFailBecauseClientAssertionJtiIsNotSet", + name: "ShouldFailBecauseClientAssertionJTIIsNotSet", client: func(ts *httptest.Server) Client { return &DefaultJARClient{DefaultClient: &DefaultClient{ID: "bar", ClientSecret: testClientSecretBar}, JSONWebKeys: jwksRSA, TokenEndpointAuthMethod: "private_key_jwt"} }, form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ @@ -692,7 +692,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Claim 'jti' from 'client_assertion' must be set but is not.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The client assertion had invalid claims. Claim 'jti' from 'client_assertion' must be set but is not.", }, { name: "ShouldFailBecauseClientAssertionAudIsNotSet", From ea8db712871082adef1d131362008d09fb88955a Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 27 Sep 2024 16:21:33 +1000 Subject: [PATCH 17/33] client auth tests --- authorize_request_handler.go | 16 +- ...orize_request_handler_oidc_request_test.go | 10 +- client_authentication_strategy.go | 16 +- client_authentication_test.go | 6 +- handler/oauth2/strategy_jwt_profile.go | 16 +- token/jwt/jwt_signer_test.go | 320 ------------- token/jwt/jwt_strategy_test.go | 449 ++++++++++++++++-- token/jwt/token.go | 22 + token/jwt/util.go | 8 + token/jwt/validation_error.go | 20 +- 10 files changed, 486 insertions(+), 397 deletions(-) delete mode 100644 token/jwt/jwt_signer_test.go diff --git a/authorize_request_handler.go b/authorize_request_handler.go index cc0f48ca..fb24c97f 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -586,17 +586,21 @@ func fmtRequestObjectDecodeError(token *jwt.Token, client JARClient, issuer stri if errJWTValidation := new(jwt.ValidationError); errors.As(inner, &errJWTValidation) { switch { case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyIDInvalid): - return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'kid' value '%s' due to the client registration 'request_object_signing_key_id' value but the request object was signed with the 'kid' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningKeyID(), token.KeyID) + return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'kid' header value '%s' due to the client registration 'request_object_signing_key_id' value but the request object was signed with the 'kid' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningKeyID(), token.KeyID) case errJWTValidation.Has(jwt.ValidationErrorHeaderAlgorithmInvalid): - return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'alg' value '%s' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningAlg(), token.SignatureAlgorithm) + return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'alg' header value '%s' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningAlg(), token.SignatureAlgorithm) case errJWTValidation.Has(jwt.ValidationErrorHeaderTypeInvalid): - return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'typ' value '%s' but the request object was signed with the 'typ' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'typ' header value '%s' but the request object was signed with the 'typ' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionTypeInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'typ' header value '%s' but the request object was encrypted with the 'typ' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'cty' header value '%s' but the request object was encrypted with the 'cty' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderContentType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): - return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'kid' value '%s' due to the client registration 'request_object_encryption_key_id' value but the request object was encrypted with the 'kid' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionKeyID(), token.EncryptionKeyID) + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'kid' header value '%s' due to the client registration 'request_object_encryption_key_id' value but the request object was encrypted with the 'kid' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionKeyID(), token.EncryptionKeyID) case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyAlgorithmInvalid): - return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'alg' value '%s' due to the client registration 'request_object_encryption_alg' value but the request object was encrypted with the 'alg' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionAlg(), token.KeyAlgorithm) + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'alg' header value '%s' due to the client registration 'request_object_encryption_alg' value but the request object was encrypted with the 'alg' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionAlg(), token.KeyAlgorithm) case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): - return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'enc' value '%s' due to the client registration 'request_object_encryption_enc' value but the request object was encrypted with the 'enc' value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionEnc(), token.ContentEncryption) + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'enc' header value '%s' due to the client registration 'request_object_encryption_enc' value but the request object was encrypted with the 'enc' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionEnc(), token.ContentEncryption) case errJWTValidation.Has(jwt.ValidationErrorMalformedNotCompactSerialized): return outer.WithDebugf("%s client with id '%s' provided a request object that was malformed. The request object does not appear to be a JWE or JWS compact serialized JWT.", hintRequestObjectPrefix(openid), client.GetID()) case errJWTValidation.Has(jwt.ValidationErrorMalformed): diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index 6f52afcf..c8059888 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -203,7 +203,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test", ClientSecret: NewPlainTextClientSecret("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'alg' value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value 'HS256'.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'alg' header value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value 'HS256'.", }, { name: "ShouldFailMismatchedClientID", @@ -279,7 +279,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value 'none'.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' header value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value 'none'.", }, { name: "ShouldFailRequestURIAlgNone", @@ -287,7 +287,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", RequestURIs: []string{root.JoinPath("request-object", "valid", "none.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value 'none'.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' header value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value 'none'.", }, { name: "ShouldFailRequestRS256", @@ -295,7 +295,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'none' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value 'RS256'.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' header value 'none' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value 'RS256'.", }, { name: "ShouldFailRequestURIRS256", @@ -303,7 +303,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, RequestURIs: []string{root.JoinPath("request-object", "valid", "standard.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "standard.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' value 'none' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' value 'RS256'.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' header value 'none' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value 'RS256'.", }, { name: "ShouldPassRequestAlgNone", diff --git a/client_authentication_strategy.go b/client_authentication_strategy.go index 6baa63c2..36ad6462 100644 --- a/client_authentication_strategy.go +++ b/client_authentication_strategy.go @@ -392,17 +392,21 @@ func fmtClientAssertionDecodeError(token *jwt.Token, client AuthenticationMethod if errJWTValidation := new(jwt.ValidationError); errors.As(inner, &errJWTValidation) { switch { case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyIDInvalid): - return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'kid' value '%s' due to the client registration 'request_object_signing_key_id' value but the client assertion was signed with the 'kid' value '%s'.", client.GetID(), handler.GetAuthSigningKeyID(client), token.KeyID) + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'kid' header value '%s' due to the client registration 'request_object_signing_key_id' value but the client assertion was signed with the 'kid' header value '%s'.", client.GetID(), handler.GetAuthSigningKeyID(client), token.KeyID) case errJWTValidation.Has(jwt.ValidationErrorHeaderAlgorithmInvalid): - return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'alg' value '%s' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' value '%s'.", client.GetID(), handler.GetAuthSigningAlg(client), token.SignatureAlgorithm) + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'alg' header value '%s' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' header value '%s'.", client.GetID(), handler.GetAuthSigningAlg(client), token.SignatureAlgorithm) case errJWTValidation.Has(jwt.ValidationErrorHeaderTypeInvalid): - return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'typ' value '%s' but the client assertion was signed with the 'typ' value '%s'.", client.GetID(), consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'typ' header value '%s' but the client assertion was signed with the 'typ' header value '%s'.", client.GetID(), consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionTypeInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'typ' header value '%s' but the client assertion was encrypted with the 'typ' header value '%s'.", client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'cty' header value '%s' but the client assertion was encrypted with the 'cty' header value '%s'.", client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderContentType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): - return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'kid' value '%s' due to the client registration 'request_object_encryption_key_id' value but the client assertion was encrypted with the 'kid' value '%s'.", client.GetID(), handler.GetAuthEncryptionKeyID(client), token.EncryptionKeyID) + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'kid' header value '%s' due to the client registration 'request_object_encryption_key_id' value but the client assertion was encrypted with the 'kid' header value '%s'.", client.GetID(), handler.GetAuthEncryptionKeyID(client), token.EncryptionKeyID) case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyAlgorithmInvalid): - return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'alg' value '%s' due to the client registration 'request_object_encryption_alg' value but the client assertion was encrypted with the 'alg' value '%s'.", client.GetID(), handler.GetAuthEncryptionAlg(client), token.KeyAlgorithm) + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'alg' header value '%s' due to the client registration 'request_object_encryption_alg' value but the client assertion was encrypted with the 'alg' header value '%s'.", client.GetID(), handler.GetAuthEncryptionAlg(client), token.KeyAlgorithm) case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): - return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'enc' value '%s' due to the client registration 'request_object_encryption_enc' value but the client assertion was encrypted with the 'enc' value '%s'.", client.GetID(), handler.GetAuthEncryptionEnc(client), token.ContentEncryption) + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'enc' header value '%s' due to the client registration 'request_object_encryption_enc' value but the client assertion was encrypted with the 'enc' header value '%s'.", client.GetID(), handler.GetAuthEncryptionEnc(client), token.ContentEncryption) case errJWTValidation.Has(jwt.ValidationErrorMalformedNotCompactSerialized): return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was malformed. The client assertion does not appear to be a JWE or JWS compact serialized JWT.", client.GetID()) case errJWTValidation.Has(jwt.ValidationErrorMalformed): diff --git a/client_authentication_test.go b/client_authentication_test.go index 3ad21753..142910bf 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -404,7 +404,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' expects client assertions to be signed with the 'alg' value 'ES256' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' value 'RS256'.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' expects client assertions to be signed with the 'alg' header value 'ES256' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' header value 'RS256'.", }, { name: "ShouldFailBecauseMalformedAssertionUsed", @@ -635,7 +635,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' expects client assertions to be signed with the 'alg' value 'RS256' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' value 'none'.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' expects client assertions to be signed with the 'alg' header value 'RS256' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' header value 'none'.", }, { name: "ShouldPassWithProperAssertionWhenJWKsURIIsSet", @@ -681,7 +681,7 @@ func TestAuthenticateClient(t *testing.T) { err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The client assertion had invalid claims. Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.", }, { - name: "ShouldFailBecauseClientAssertionJTIIsNotSet", + name: "ShouldFailBecauseClientAssertionJTIClaimIsNotSet", client: func(ts *httptest.Server) Client { return &DefaultJARClient{DefaultClient: &DefaultClient{ID: "bar", ClientSecret: testClientSecretBar}, JSONWebKeys: jwksRSA, TokenEndpointAuthMethod: "private_key_jwt"} }, form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ diff --git a/handler/oauth2/strategy_jwt_profile.go b/handler/oauth2/strategy_jwt_profile.go index eb3f9c31..7dd51d5d 100644 --- a/handler/oauth2/strategy_jwt_profile.go +++ b/handler/oauth2/strategy_jwt_profile.go @@ -202,17 +202,21 @@ func fmtValidateJWTError(token *jwt.Token, client jwt.Client, inner error) (err if errJWTValidation := new(jwt.ValidationError); errors.As(inner, &errJWTValidation) { switch { case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyIDInvalid): - return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'kid' value '%s' but it was signed with the 'kid' value '%s'.", clientText, sigKID, token.KeyID) + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'kid' header value '%s' but it was signed with the 'kid' header value '%s'.", clientText, sigKID, token.KeyID) case errJWTValidation.Has(jwt.ValidationErrorHeaderAlgorithmInvalid): - return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'alg' value '%s' but it was signed with the 'alg' value '%s'.", clientText, sigAlg, token.SignatureAlgorithm) + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'alg' header value '%s' but it was signed with the 'alg' header value '%s'.", clientText, sigAlg, token.SignatureAlgorithm) case errJWTValidation.Has(jwt.ValidationErrorHeaderTypeInvalid): - return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'typ' value '%s' but it was signed with the 'typ' value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'typ' header value '%s' but it was signed with the 'typ' header value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionTypeInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'typ' header value '%s' but it was encrypted with the 'typ' header value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'cty' header value '%s' but it was encrypted with the 'cty' header value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderContentType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): - return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'kid' value '%s' but it was encrypted with the 'kid' value '%s'.", clientText, encKID, token.EncryptionKeyID) + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'kid' header value '%s' but it was encrypted with the 'kid' header value '%s'.", clientText, encKID, token.EncryptionKeyID) case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyAlgorithmInvalid): - return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'alg' value '%s' but it was encrypted with the 'alg' value '%s'.", clientText, encAlg, token.KeyAlgorithm) + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'alg' header value '%s' but it was encrypted with the 'alg' header value '%s'.", clientText, encAlg, token.KeyAlgorithm) case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): - return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'enc' value '%s' but it was encrypted with the 'enc' value '%s'.", clientText, enc, token.ContentEncryption) + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'enc' header value '%s' but it was encrypted with the 'enc' header value '%s'.", clientText, enc, token.ContentEncryption) case errJWTValidation.Has(jwt.ValidationErrorMalformedNotCompactSerialized): return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis malformed. The token does not appear to be a JWE or JWS compact serialized JWT.", clientText) case errJWTValidation.Has(jwt.ValidationErrorMalformed): diff --git a/token/jwt/jwt_signer_test.go b/token/jwt/jwt_signer_test.go deleted file mode 100644 index f1cf090f..00000000 --- a/token/jwt/jwt_signer_test.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright © 2023 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package jwt - -/* - -import ( - "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "encoding/json" - "fmt" - "strings" - "testing" - "time" - - "github.com/go-jose/go-jose/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "authelia.com/provider/oauth2/internal/gen" -) - - -var header = &Headers{ - Extra: map[string]any{ - "foo": "bar", - }, -} - -func TestEncrypt(t *testing.T) { - i, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) - require.NoError(t, err) - - c, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) - require.NoError(t, err) - - issuer := jose.JSONWebKey{ - Key: i, - KeyID: "iss-abc123-es512", - Algorithm: string(jose.ES512), - Use: "sig", - } - - clientP := jose.JSONWebKey{ - Key: c, - KeyID: "client-abc123-es512", - Algorithm: string(jose.ECDH_ES_A256KW), - Use: "enc", - } - - client := jose.JSONWebKey{ - Key: &c.PublicKey, - KeyID: "client-abc123-es512", - Algorithm: string(jose.ECDH_ES_A256KW), - Use: "enc", - } - - issuerPublic := jose.JSONWebKey{ - Key: &i.PublicKey, - KeyID: "iss-abc123-es512", - Algorithm: string(jose.ES512), - Use: "sig", - } - - key := make([]byte, 64) - - _, err = rand.Read(key) - require.NoError(t, err) - - issuerDirect := jose.JSONWebKey{ - Key: key, - KeyID: "iss-abc123-es512", - Algorithm: string(jose.DIRECT), - Use: "enc", - } - - data, err := json.Marshal(issuer) - require.NoError(t, err) - fmt.Println(string(data)) - - data, err = json.Marshal(issuer.Public()) - require.NoError(t, err) - fmt.Println(string(data)) - - data, err = json.Marshal(issuerPublic) - require.NoError(t, err) - fmt.Println(string(data)) - - data, err = json.Marshal(issuerPublic.Public()) - require.NoError(t, err) - fmt.Println(string(data)) - - data, err = json.Marshal(client) - require.NoError(t, err) - fmt.Println(string(data)) - - data, err = json.Marshal(clientP) - require.NoError(t, err) - fmt.Println(string(data)) - - data, err = json.Marshal(issuerDirect) - require.NoError(t, err) - fmt.Println(string(data)) - - jwk2 := New() - jwk := New() - - claims := MapClaims{ - "name": "example", - } - - jwsHeaders := &Headers{} - jweHeaders := &Headers{} - - jwk.SetJWS(jwsHeaders, claims, jose.SignatureAlgorithm(issuer.Algorithm)) - jwk2.SetJWS(jwsHeaders, claims, jose.ES256) - jwk.SetJWE(jweHeaders, jose.KeyAlgorithm(client.Algorithm), jose.A256GCM, jose.NONE) - - token, signature, err := jwk.CompactEncrypted(&issuer, &client) - require.NoError(t, err) - - fmt.Println(token) - fmt.Println(signature) - - token, signature, err = jwk2.CompactSigned(&issuer) - require.NoError(t, err) - - fmt.Println(token) - fmt.Println(signature) -} - -func TestHash(t *testing.T) { - for k, tc := range []struct { - d string - strategy Signer - }{ - { - d: "RS256", - strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil - }}, - }, - { - d: "ES256", - strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustES256Key(), nil - }}, - }, - } { - t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { - in := []byte("foo") - out, err := tc.strategy.Hash(context.TODO(), in) - assert.NoError(t, err) - assert.NotEqual(t, in, out) - }) - } -} - -func TestAssign(t *testing.T) { - for k, c := range [][]map[string]any{ - { - {"foo": "bar"}, - {"baz": "bar"}, - {"foo": "bar", "baz": "bar"}, - }, - { - {"foo": "bar"}, - {"foo": "baz"}, - {"foo": "bar"}, - }, - { - {}, - {"foo": "baz"}, - {"foo": "baz"}, - }, - { - {"foo": "bar"}, - {"foo": "baz", "bar": "baz"}, - {"foo": "bar", "bar": "baz"}, - }, - } { - assert.EqualValues(t, c[2], assign(c[0], c[1]), "Case %d", k) - } -} - -func TestGenerateJWT(t *testing.T) { - testCases := []struct { - name string - key func() any - }{ - { - name: "DefaultSigner", - key: func() any { - return gen.MustRSAKey() - }, - }, - { - name: "ES256JWTStrategy", - key: func() any { - return gen.MustES256Key() - }, - }, - { - name: "ES256JWTStrategyWithJSONWebKey", - key: func() any { - return &jose.JSONWebKey{ - Key: gen.MustES521Key(), - Algorithm: "ES512", - } - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - - key := tc.key() - - strategy := &DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, - } - - claims := &JWTClaims{ - ExpiresAt: time.Now().UTC().Add(time.Hour), - } - - token, sig, err := strategy.Generate(ctx, claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - assert.NotEmpty(t, sig) - - sig, err = strategy.Validate(ctx, token) - require.NoError(t, err) - assert.NotEmpty(t, sig) - - sig, err = strategy.Validate(ctx, token+"."+"0123456789") - require.Error(t, err) - assert.Empty(t, sig) - - partToken := strings.Split(token, ".")[2] - - sig, err = strategy.Validate(ctx, partToken) - require.Error(t, err) - assert.Empty(t, sig) - - key = tc.key() - - claims = &JWTClaims{ - ExpiresAt: time.Now().UTC().Add(-time.Hour), - } - - token, sig, err = strategy.Generate(ctx, claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - assert.NotEmpty(t, sig) - - sig, err = strategy.Validate(ctx, token) - require.Error(t, err) - require.Empty(t, sig) - - claims = &JWTClaims{ - NotBefore: time.Now().UTC().Add(time.Hour), - } - - token, sig, err = strategy.Generate(ctx, claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - assert.NotEmpty(t, sig) - - sig, err = strategy.Validate(ctx, token) - require.Error(t, err) - require.Empty(t, sig, "%s", err) - }) - } -} - -func TestValidateSignatureRejectsJWT(t *testing.T) { - for k, tc := range []struct { - d string - strategy Signer - }{ - { - d: "RS256", - strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil - }, - }, - }, - { - d: "ES256", - strategy: &DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustES256Key(), nil - }, - }, - }, - } { - t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { - for k, c := range []string{ - "", - " ", - "foo.bar", - "foo.", - ".foo", - } { - _, err := tc.strategy.Validate(context.TODO(), c) - assert.Error(t, err) - t.Logf("Passed test case %d", k) - } - }) - } -} - -*/ diff --git a/token/jwt/jwt_strategy_test.go b/token/jwt/jwt_strategy_test.go index 65c37418..60f8ca61 100644 --- a/token/jwt/jwt_strategy_test.go +++ b/token/jwt/jwt_strategy_test.go @@ -6,12 +6,15 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/rsa" + "crypto/x509" "encoding/json" "fmt" "net/http" "testing" + "time" "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "authelia.com/provider/oauth2/internal/consts" @@ -217,16 +220,10 @@ func TestDefaultStrategy(t *testing.T) { token1, signature1, err = strategy.Encode(ctx, WithClaims(claims), WithHeaders(headers1), WithClient(client)) require.NoError(t, err) + assert.NotEmpty(t, signature1) require.True(t, IsSignedJWT(token1)) - fmt.Println("---------") - fmt.Println("Token 1:") - fmt.Println("\tValue:", token1) - fmt.Println("\tSignature:", signature1) - fmt.Println("---------") - fmt.Println("") - headersEnc = &Headers{} var ( @@ -242,13 +239,7 @@ func TestDefaultStrategy(t *testing.T) { token2, signature2, err = strategy.Encode(ctx, WithClaims(claims), WithHeaders(headers2), WithHeadersJWE(headersEnc), WithClient(clientEnc)) require.NoError(t, err) require.True(t, IsEncryptedJWT(token2)) - - fmt.Println("---------") - fmt.Println("Token 2:") - fmt.Println("\tValue:", token2) - fmt.Println("\tSignature:", signature2) - fmt.Println("---------") - fmt.Println("") + require.NotEmpty(t, signature2) var ( token3, signature3 string @@ -256,13 +247,7 @@ func TestDefaultStrategy(t *testing.T) { token3, signature3, err = strategy.Encode(ctx, WithClaims(claims), WithHeaders(headers1), WithHeadersJWE(headersEnc), WithClient(clientEncAsymmetric)) require.NoError(t, err) - - fmt.Println("---------") - fmt.Println("Token 3:") - fmt.Println("\tValue:", token3) - fmt.Println("\tSignature:", signature3) - fmt.Println("---------") - fmt.Println("") + assert.NotEmpty(t, signature3) clientIssuer := &DefaultIssuer{ jwks: clientIssuerJWKS, @@ -286,40 +271,292 @@ func TestDefaultStrategy(t *testing.T) { tokenString, signature, jwe, err := clientStrategy.Decrypt(ctx, token2, WithClient(clientEncAsymmetric)) require.NoError(t, err) - - fmt.Println("---------") - fmt.Println("Token 2 (Decrypted):") - fmt.Println("\tValue:", tokenString) - fmt.Println("\tSignature:", signature) - fmt.Println("\tJWE:", jwe) - fmt.Println("---------") - fmt.Println("") + assert.NotEmpty(t, signature) + assert.NotEmpty(t, tokenString) + assert.NotNil(t, jwe) tokenString, signature, jwe, err = clientStrategy.Decrypt(ctx, token3, WithClient(clientEncAsymmetric)) require.NoError(t, err) - fmt.Println("---------") - fmt.Println("Token 3 (Decrypted):") - fmt.Println("\tValue:", tokenString) - fmt.Println("\tSignature:", signature) - fmt.Println("\tJWE:", jwe) - fmt.Println("---------") - fmt.Println("") - tok, err := clientStrategy.Decode(ctx, token1, WithClient(issuerClient)) require.NoError(t, err) - - fmt.Printf("%v+\n", tok) + require.NotNil(t, tok) tok, err = clientStrategy.Decode(ctx, token2, WithClient(issuerClient)) require.NoError(t, err) - fmt.Printf("%v+\n", tok) - tok, err = clientStrategy.Decode(ctx, token3, WithClient(clientEncAsymmetric)) require.NoError(t, err) + require.NotNil(t, tok) +} + +func TestDefaultStrategy_Decode_RejectNonCompactSerializedJWT(t *testing.T) { + testCases := []struct { + name string + strategy Strategy + }{ + { + name: "RS256", + strategy: &DefaultStrategy{}, + }, + { + name: "ES256", + strategy: &DefaultStrategy{}, + }, + } + + inputs := []struct { + name string + value string + }{ + {"Empty", ""}, + {"Space", " "}, + {"TwoParts", "foo.bar"}, + {"TwoPartsEmptySecond", "foo."}, + {"TwoPartsEmptyFirst", "foo."}, + } + + for _, tc := range testCases { + for _, input := range inputs { + t.Run(fmt.Sprintf("%s/%s", tc.name, input.name), func(t *testing.T) { + _, err := tc.strategy.Decode(context.TODO(), input.value) + + assert.EqualError(t, err, "Provided value does not appear to be a JWE or JWS compact serialized JWT") + }) + } + } +} + +func TestNestedJWTEncodeDecode(t *testing.T) { + claims := MapClaims{ + "iss": "example.com", + "sub": "john", + "iat": time.Now().UTC().Unix(), + "exp": time.Now().Add(time.Hour).UTC().Unix(), + "aud": []string{"test"}, + } + + providerStrategy := &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeySigRSA, + testKeySigECDSA, + }, + }), + } + + encodeClientRSA := &testClient{ + id: "test", + kid: "test-rsa-sig", + alg: string(jose.RS256), + encKID: "test-rsa-enc", + encAlg: string(jose.RSA_OAEP_256), + enc: string(jose.A128GCM), + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicEncRSA, + testKeyPublicEncECDSA, + }, + }, + } + + tokenString, sig, err := providerStrategy.Encode(context.TODO(), WithClaims(claims), WithClient(encodeClientRSA)) + require.NoError(t, err) + assert.NotEmpty(t, sig) + assert.NotEmpty(t, tokenString) + + clientStrategy := &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyEncRSA, + testKeyEncECDSA, + }, + }), + } + + decodeClientRSA := &testClient{ + id: "test", + kid: "test-rsa-sig", + alg: string(jose.RS256), + encKID: "test-rsa-enc", + encAlg: string(jose.RSA_OAEP_256), + enc: string(jose.A128GCM), + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicSigRSA, + testKeyPublicSigECDSA, + }, + }, + csigned: true, + } + + token, err := clientStrategy.Decode(context.TODO(), tokenString, WithClient(decodeClientRSA)) + require.NoError(t, err) + + assert.NotNil(t, token) + + assert.NoError(t, token.Valid(ValidateAlgorithm(string(jose.RS256)), ValidateKeyAlgorithm(string(jose.RSA_OAEP_256)), ValidateContentEncryption(string(jose.A128GCM)), ValidateKeyID("test-rsa-sig"), ValidateEncryptionKeyID("test-rsa-enc"))) + assert.NoError(t, token.Claims.Valid(ValidateRequireExpiresAt(), ValidateRequireIssuedAt(), ValidateIssuer("example.com"), ValidateAudienceAny("test"))) + assert.EqualError(t, token.Claims.Valid(ValidateAudienceAny("nope")), "Token has invalid audience") + + encodeClientECDSA := &testClient{ + id: "test", + kid: "test-ecdsa-sig", + alg: string(jose.ES256), + encKID: "test-ecdsa-enc", + encAlg: string(jose.ECDH_ES_A128KW), + enc: string(jose.A128GCM), + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicEncRSA, + testKeyPublicEncECDSA, + }, + }, + } + + tokenString, sig, err = providerStrategy.Encode(context.TODO(), WithClaims(claims), WithClient(encodeClientECDSA)) + require.NoError(t, err) + assert.NotEmpty(t, sig) + assert.NotEmpty(t, tokenString) + + clientStrategy = &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyEncRSA, + testKeyEncECDSA, + }, + }), + } + + decodeClientECDSA := &testClient{ + id: "test", + kid: "test-ecdsa-sig", + alg: string(jose.RS256), + encKID: "test-ecdsa-enc", + encAlg: string(jose.RSA_OAEP_256), + enc: string(jose.A128GCM), + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicSigRSA, + testKeyPublicSigECDSA, + }, + }, + csigned: true, + } + + token, err = clientStrategy.Decode(context.TODO(), tokenString, WithClient(decodeClientECDSA)) + require.NoError(t, err) + + assert.NotNil(t, token) + + assert.NoError(t, token.Valid(ValidateAlgorithm(string(jose.ES256)), ValidateKeyAlgorithm(string(jose.ECDH_ES_A128KW)), ValidateContentEncryption(string(jose.A128GCM)), ValidateKeyID("test-ecdsa-sig"), ValidateEncryptionKeyID("test-ecdsa-enc"))) + assert.NoError(t, token.Claims.Valid(ValidateRequireExpiresAt(), ValidateRequireIssuedAt(), ValidateIssuer("example.com"), ValidateAudienceAny("test"))) + assert.EqualError(t, token.Claims.Valid(ValidateAudienceAny("nope")), "Token has invalid audience") + + k, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + decodeClientECDSA = &testClient{ + id: "test", + kid: "test-ecdsa-sig", + alg: string(jose.RS256), + encKID: "test-ecdsa-enc", + encAlg: string(jose.RSA_OAEP_256), + enc: string(jose.A128GCM), + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicSigRSA, + { + Key: k, + KeyID: "test-ecdsa-sig", + Use: "sig", + Algorithm: string(jose.ES256), + }, + }, + }, + csigned: true, + } + + token, err = clientStrategy.Decode(context.TODO(), tokenString, WithClient(decodeClientECDSA)) + assert.EqualError(t, err, "go-jose/go-jose: error in cryptographic primitive") - fmt.Printf("%v+\n", tok) + clientStrategy = &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyEncRSA, + }, + }), + } + + token, err = clientStrategy.Decode(context.TODO(), tokenString, WithClient(decodeClientECDSA)) + assert.EqualError(t, err, "Error occurred retrieving the JSON Web Key. The JSON Web Token uses signing key with kid 'test-ecdsa-enc' which was not found") + + clientStrategy = &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyEncRSA, + { + Key: k, + KeyID: "test-ecdsa-enc", + Algorithm: string(jose.ECDH_ES_A128KW), + Use: "enc", + }, + }, + }), + } + + token, err = clientStrategy.Decode(context.TODO(), tokenString, WithClient(decodeClientECDSA)) + assert.EqualError(t, err, "go-jose/go-jose: error in cryptographic primitive") +} + +func TestShouldDecodeEncrypedTokens(t *testing.T) { + testCases := []struct { + name string + have string + }{ + { + "ShouldDecodeRS256", + testCompactSerializedNestedJWEWithRSA, + }, + { + "ShouldDecodeES256", + testCompactSerializedNestedJWEWithECDSA, + }, + } + + for _, tc := range testCases { + strategy := &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyEncRSA, + testKeyEncECDSA, + }, + }), + } + + client := &testClient{ + id: "test", + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicSigRSA, + testKeyPublicSigECDSA, + }, + }, + csigned: true, + } + + token, err := strategy.Decode(context.Background(), tc.have, WithClient(client)) + assert.NoError(t, err) + assert.NotNil(t, token) + + assert.NoError(t, token.Valid()) + assert.NoError(t, token.Claims.Valid(ValidateIssuer("example.com"), ValidateRequireIssuedAt(), ValidateRequireExpiresAt(), ValidateSubject("john"))) + } } type testConfig struct{} @@ -359,3 +596,131 @@ func (f *testFetcher) Resolve(ctx context.Context, location string, _ bool) (jwk return jwks, nil } + +var ( + testKeyBytesSigRSA = []byte{48, 130, 4, 189, 2, 1, 0, 48, 13, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 1, 5, 0, 4, 130, 4, 167, 48, 130, 4, 163, 2, 1, 0, 2, 130, 1, 1, 0, 225, 139, 207, 167, 201, 162, 77, 23, 186, 56, 185, 177, 22, 31, 170, 37, 145, 252, 98, 102, 229, 160, 101, 31, 227, 76, 122, 70, 190, 1, 96, 53, 29, 58, 134, 211, 50, 166, 81, 143, 163, 141, 118, 165, 192, 238, 26, 250, 66, 42, 204, 164, 160, 205, 124, 175, 215, 202, 172, 74, 210, 134, 147, 10, 64, 58, 134, 46, 185, 15, 164, 126, 106, 106, 89, 206, 12, 95, 74, 229, 236, 132, 101, 98, 140, 40, 94, 177, 113, 24, 215, 184, 7, 210, 23, 98, 129, 207, 27, 221, 233, 168, 88, 68, 180, 157, 65, 141, 145, 53, 186, 230, 143, 128, 77, 47, 240, 181, 142, 236, 105, 119, 177, 231, 0, 123, 86, 215, 201, 165, 158, 234, 185, 236, 225, 72, 237, 28, 192, 193, 45, 212, 194, 19, 195, 9, 140, 141, 143, 3, 148, 118, 244, 25, 50, 18, 9, 37, 199, 202, 225, 142, 44, 138, 99, 84, 25, 162, 48, 243, 44, 31, 38, 88, 62, 244, 2, 64, 84, 53, 105, 92, 175, 156, 152, 32, 191, 26, 39, 153, 166, 188, 190, 33, 221, 116, 180, 174, 229, 191, 129, 198, 183, 23, 178, 20, 183, 250, 29, 42, 63, 63, 107, 170, 126, 121, 31, 90, 17, 180, 63, 137, 86, 255, 49, 163, 181, 91, 110, 160, 163, 147, 67, 149, 152, 218, 212, 23, 231, 76, 133, 208, 190, 161, 230, 4, 156, 186, 206, 145, 106, 7, 213, 50, 161, 159, 2, 3, 1, 0, 1, 2, 130, 1, 0, 106, 140, 145, 220, 193, 244, 90, 87, 11, 50, 33, 6, 247, 92, 158, 20, 129, 146, 169, 41, 210, 240, 162, 213, 29, 155, 211, 103, 247, 250, 206, 104, 73, 22, 140, 250, 216, 194, 153, 101, 49, 238, 114, 78, 123, 134, 0, 88, 153, 73, 126, 195, 134, 243, 140, 35, 197, 221, 136, 231, 15, 237, 99, 41, 68, 142, 97, 53, 81, 87, 130, 109, 245, 247, 167, 213, 31, 35, 37, 78, 217, 28, 242, 136, 75, 142, 6, 173, 236, 175, 191, 184, 192, 121, 15, 115, 9, 191, 189, 122, 104, 23, 143, 27, 101, 247, 164, 48, 44, 153, 37, 98, 38, 8, 134, 110, 79, 88, 117, 220, 89, 54, 162, 100, 110, 101, 213, 239, 215, 216, 210, 212, 103, 101, 141, 155, 163, 163, 92, 200, 89, 244, 21, 136, 197, 41, 119, 24, 87, 64, 179, 1, 128, 223, 166, 65, 16, 163, 99, 42, 251, 49, 59, 200, 176, 174, 9, 79, 90, 174, 171, 221, 68, 38, 200, 224, 123, 116, 79, 105, 97, 196, 164, 173, 200, 47, 199, 130, 84, 201, 111, 87, 76, 249, 117, 200, 83, 104, 195, 123, 171, 176, 39, 221, 101, 23, 152, 39, 148, 179, 79, 32, 135, 148, 86, 252, 85, 226, 50, 222, 84, 230, 174, 202, 149, 64, 87, 170, 2, 7, 20, 26, 118, 251, 7, 161, 55, 11, 127, 143, 78, 169, 209, 225, 94, 173, 164, 149, 37, 191, 21, 182, 38, 56, 168, 129, 2, 129, 129, 0, 238, 131, 86, 36, 159, 247, 119, 204, 99, 214, 255, 44, 244, 160, 224, 151, 30, 172, 198, 6, 76, 189, 52, 147, 11, 99, 164, 161, 245, 49, 224, 145, 118, 186, 229, 106, 207, 53, 208, 16, 74, 222, 118, 57, 230, 237, 5, 165, 224, 90, 194, 146, 162, 98, 85, 30, 162, 195, 214, 117, 130, 141, 43, 225, 169, 222, 247, 190, 77, 187, 244, 50, 2, 35, 153, 192, 253, 128, 34, 227, 128, 209, 145, 41, 192, 79, 185, 78, 169, 144, 71, 211, 58, 107, 50, 125, 152, 174, 177, 58, 121, 239, 95, 47, 248, 156, 53, 97, 126, 48, 112, 232, 206, 60, 139, 36, 111, 213, 98, 254, 233, 211, 168, 187, 115, 205, 142, 45, 2, 129, 129, 0, 242, 21, 25, 43, 57, 228, 145, 144, 61, 19, 122, 145, 201, 78, 122, 47, 209, 218, 227, 90, 209, 224, 174, 252, 191, 209, 172, 99, 217, 112, 127, 131, 200, 134, 90, 159, 37, 16, 46, 86, 118, 145, 140, 31, 83, 194, 111, 23, 83, 18, 151, 223, 126, 58, 186, 235, 33, 180, 24, 200, 101, 114, 148, 199, 203, 57, 190, 239, 21, 45, 194, 140, 182, 45, 222, 30, 162, 173, 189, 249, 203, 158, 162, 138, 246, 185, 164, 21, 216, 228, 146, 180, 162, 165, 48, 170, 215, 113, 204, 223, 200, 194, 140, 54, 191, 157, 251, 30, 218, 126, 228, 27, 228, 158, 30, 126, 131, 243, 169, 230, 172, 88, 183, 51, 70, 241, 218, 123, 2, 129, 129, 0, 156, 61, 214, 201, 73, 45, 0, 2, 25, 8, 246, 193, 201, 66, 53, 189, 104, 239, 191, 12, 211, 106, 66, 45, 109, 17, 138, 0, 58, 49, 193, 45, 40, 252, 199, 90, 79, 128, 173, 218, 110, 97, 10, 75, 101, 213, 176, 148, 119, 194, 156, 161, 23, 212, 152, 115, 232, 37, 167, 175, 244, 164, 107, 177, 120, 232, 193, 155, 157, 42, 89, 142, 4, 206, 179, 98, 179, 237, 35, 109, 170, 174, 29, 140, 159, 24, 218, 136, 8, 21, 166, 167, 93, 38, 105, 189, 210, 173, 229, 21, 44, 89, 61, 30, 156, 154, 31, 113, 205, 11, 8, 123, 200, 213, 234, 68, 37, 42, 64, 158, 66, 40, 79, 232, 243, 180, 28, 197, 2, 129, 128, 115, 240, 44, 212, 169, 238, 80, 212, 142, 155, 180, 152, 251, 155, 77, 35, 119, 210, 232, 14, 7, 244, 30, 122, 71, 247, 200, 35, 45, 241, 21, 240, 236, 105, 132, 31, 49, 229, 244, 251, 77, 223, 217, 6, 235, 219, 115, 206, 236, 231, 59, 187, 58, 190, 47, 229, 10, 136, 49, 82, 80, 91, 182, 235, 148, 229, 252, 14, 142, 203, 18, 160, 199, 99, 98, 60, 179, 214, 151, 228, 121, 99, 105, 31, 58, 152, 160, 0, 34, 151, 29, 183, 203, 41, 104, 12, 122, 16, 51, 121, 125, 177, 198, 235, 53, 140, 24, 199, 167, 7, 28, 130, 75, 84, 122, 240, 70, 139, 188, 244, 15, 216, 145, 44, 202, 174, 107, 223, 2, 129, 128, 106, 85, 157, 106, 91, 201, 27, 113, 197, 111, 239, 104, 141, 30, 73, 67, 30, 204, 18, 195, 1, 99, 13, 200, 69, 81, 13, 185, 250, 196, 26, 127, 67, 184, 226, 65, 176, 119, 163, 86, 176, 24, 120, 179, 50, 36, 76, 156, 108, 138, 164, 204, 65, 133, 112, 236, 122, 246, 227, 137, 244, 216, 112, 246, 212, 114, 24, 155, 88, 42, 17, 161, 70, 196, 67, 90, 209, 73, 58, 73, 82, 26, 116, 15, 229, 107, 35, 158, 89, 49, 241, 154, 7, 230, 219, 92, 234, 144, 136, 4, 221, 149, 130, 120, 64, 127, 225, 248, 241, 183, 6, 25, 225, 10, 236, 21, 141, 152, 122, 70, 111, 82, 177, 175, 205, 116, 72, 142} + testKeyBytesEncRSA = []byte{48, 130, 4, 191, 2, 1, 0, 48, 13, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 1, 5, 0, 4, 130, 4, 169, 48, 130, 4, 165, 2, 1, 0, 2, 130, 1, 1, 0, 172, 196, 70, 138, 2, 105, 68, 113, 29, 120, 167, 117, 74, 253, 195, 218, 40, 62, 221, 198, 42, 216, 67, 65, 21, 181, 204, 211, 51, 45, 62, 127, 3, 219, 96, 95, 31, 191, 226, 255, 108, 87, 135, 133, 134, 197, 74, 188, 3, 244, 121, 123, 171, 192, 82, 213, 54, 61, 142, 226, 42, 71, 79, 59, 30, 197, 1, 67, 182, 236, 39, 62, 5, 234, 69, 4, 167, 72, 82, 76, 46, 146, 234, 117, 99, 90, 189, 205, 19, 75, 210, 105, 225, 110, 172, 236, 14, 158, 33, 176, 193, 58, 136, 147, 140, 151, 222, 181, 95, 11, 121, 56, 172, 215, 239, 222, 200, 76, 237, 183, 50, 104, 240, 1, 164, 197, 247, 180, 54, 216, 196, 34, 137, 211, 108, 74, 222, 188, 190, 84, 118, 244, 249, 97, 192, 147, 126, 67, 209, 24, 80, 37, 180, 88, 169, 112, 37, 242, 249, 49, 28, 152, 168, 182, 175, 21, 80, 88, 153, 23, 132, 136, 53, 25, 84, 128, 216, 88, 118, 173, 4, 241, 238, 122, 226, 190, 134, 116, 167, 94, 196, 131, 175, 156, 213, 115, 140, 105, 63, 15, 31, 237, 25, 243, 85, 156, 85, 37, 4, 238, 154, 36, 10, 235, 11, 213, 222, 238, 208, 69, 197, 201, 139, 76, 3, 137, 214, 175, 63, 112, 150, 41, 24, 122, 110, 250, 27, 14, 48, 7, 63, 158, 35, 107, 59, 185, 141, 200, 178, 123, 10, 97, 247, 100, 9, 85, 2, 3, 1, 0, 1, 2, 130, 1, 1, 0, 152, 106, 91, 244, 187, 37, 213, 60, 153, 140, 108, 231, 156, 109, 253, 207, 195, 123, 154, 185, 141, 232, 214, 132, 95, 187, 208, 100, 110, 156, 182, 170, 229, 99, 47, 69, 28, 68, 115, 229, 116, 214, 79, 119, 236, 42, 183, 192, 225, 24, 87, 232, 83, 224, 74, 243, 80, 115, 196, 79, 32, 143, 98, 133, 188, 162, 126, 120, 23, 179, 132, 247, 65, 206, 168, 110, 239, 137, 109, 25, 74, 105, 80, 48, 153, 163, 95, 24, 193, 178, 61, 130, 45, 96, 47, 107, 221, 133, 130, 33, 102, 134, 214, 32, 157, 131, 9, 246, 38, 80, 127, 244, 17, 0, 59, 220, 230, 6, 128, 29, 3, 122, 242, 105, 240, 204, 185, 182, 47, 40, 194, 94, 4, 152, 28, 251, 15, 21, 148, 149, 219, 180, 253, 154, 53, 60, 3, 82, 206, 27, 97, 137, 228, 105, 87, 0, 66, 43, 198, 230, 33, 30, 84, 92, 138, 116, 76, 202, 250, 90, 102, 81, 67, 19, 159, 126, 130, 200, 208, 208, 60, 228, 166, 103, 212, 19, 138, 196, 57, 183, 244, 13, 175, 147, 198, 124, 164, 40, 50, 3, 64, 158, 35, 238, 116, 55, 215, 168, 63, 173, 11, 78, 244, 200, 130, 120, 89, 164, 58, 35, 68, 254, 141, 69, 157, 97, 123, 97, 217, 112, 51, 2, 126, 109, 182, 243, 11, 147, 93, 125, 140, 86, 157, 156, 31, 197, 119, 225, 67, 65, 64, 190, 173, 0, 143, 1, 2, 129, 129, 0, 207, 234, 215, 66, 208, 61, 102, 219, 140, 108, 97, 58, 204, 132, 211, 206, 210, 223, 212, 210, 6, 208, 177, 231, 250, 31, 35, 164, 213, 46, 55, 179, 221, 134, 79, 253, 76, 75, 187, 250, 75, 70, 163, 150, 122, 190, 243, 196, 169, 187, 255, 91, 67, 42, 110, 35, 23, 169, 114, 47, 165, 89, 10, 97, 50, 61, 53, 45, 153, 81, 209, 82, 82, 20, 4, 15, 61, 204, 185, 24, 246, 63, 153, 3, 135, 170, 45, 112, 59, 36, 209, 49, 9, 224, 203, 226, 40, 178, 20, 19, 19, 73, 249, 21, 122, 188, 247, 161, 20, 226, 124, 14, 66, 4, 12, 37, 75, 35, 37, 36, 22, 216, 236, 189, 44, 179, 102, 193, 2, 129, 129, 0, 212, 184, 112, 164, 6, 140, 127, 190, 0, 170, 229, 42, 252, 99, 70, 29, 251, 176, 41, 87, 91, 226, 204, 173, 164, 208, 192, 70, 0, 238, 113, 97, 239, 147, 247, 156, 80, 98, 231, 47, 28, 186, 99, 25, 144, 18, 165, 68, 117, 99, 54, 99, 241, 109, 77, 82, 250, 105, 218, 148, 205, 139, 49, 181, 68, 110, 50, 146, 212, 149, 191, 97, 10, 83, 55, 240, 23, 114, 83, 116, 25, 118, 195, 85, 72, 74, 142, 57, 97, 31, 65, 4, 93, 46, 177, 32, 248, 146, 5, 0, 189, 181, 123, 116, 27, 197, 98, 195, 158, 196, 112, 60, 177, 31, 246, 237, 246, 126, 248, 63, 245, 185, 1, 224, 165, 186, 251, 149, 2, 129, 129, 0, 150, 10, 132, 249, 68, 73, 107, 54, 184, 169, 101, 169, 6, 250, 59, 215, 159, 57, 195, 221, 36, 233, 233, 216, 220, 25, 40, 161, 196, 237, 171, 104, 243, 77, 255, 223, 108, 245, 162, 91, 199, 130, 220, 126, 181, 105, 163, 132, 162, 112, 118, 160, 167, 97, 177, 69, 69, 200, 20, 12, 234, 39, 205, 99, 194, 219, 132, 202, 185, 63, 223, 236, 166, 42, 167, 155, 80, 31, 178, 219, 158, 168, 218, 133, 63, 155, 193, 90, 162, 115, 185, 58, 200, 68, 31, 29, 191, 252, 114, 156, 41, 105, 82, 132, 251, 163, 238, 151, 161, 248, 167, 73, 170, 190, 60, 253, 148, 177, 114, 22, 15, 30, 208, 8, 220, 127, 66, 129, 2, 129, 128, 60, 200, 195, 111, 43, 107, 228, 104, 191, 186, 21, 168, 21, 220, 172, 65, 143, 21, 4, 139, 48, 247, 122, 243, 55, 128, 107, 32, 213, 205, 76, 218, 230, 97, 202, 196, 128, 247, 242, 5, 181, 88, 209, 78, 145, 171, 178, 76, 0, 155, 44, 4, 157, 13, 85, 166, 27, 102, 58, 14, 129, 57, 128, 39, 194, 249, 22, 60, 124, 192, 153, 162, 58, 24, 19, 136, 232, 186, 67, 124, 142, 118, 48, 84, 227, 70, 98, 163, 164, 204, 16, 129, 21, 187, 108, 227, 246, 3, 139, 168, 109, 141, 57, 76, 177, 78, 210, 237, 1, 38, 50, 200, 52, 248, 228, 79, 149, 59, 44, 230, 225, 233, 78, 207, 9, 172, 135, 141, 2, 129, 129, 0, 159, 97, 92, 8, 76, 63, 113, 168, 5, 233, 58, 228, 155, 140, 19, 198, 225, 22, 110, 129, 71, 157, 118, 55, 94, 254, 112, 242, 244, 198, 217, 99, 175, 76, 225, 115, 14, 167, 243, 188, 219, 109, 122, 165, 15, 135, 5, 225, 235, 233, 121, 203, 211, 175, 135, 69, 205, 46, 93, 119, 95, 41, 226, 44, 233, 73, 97, 76, 15, 211, 139, 211, 206, 9, 176, 77, 30, 228, 93, 183, 93, 115, 237, 93, 140, 190, 8, 156, 58, 73, 145, 4, 224, 39, 159, 83, 234, 216, 224, 97, 165, 120, 123, 253, 159, 147, 119, 201, 144, 155, 171, 114, 129, 31, 44, 3, 58, 232, 180, 193, 155, 180, 143, 220, 132, 233, 153, 219} + testKeySigRSA jose.JSONWebKey + testKeyPublicSigRSA jose.JSONWebKey + testKeyEncRSA jose.JSONWebKey + testKeyPublicEncRSA jose.JSONWebKey + + testKeyBytesSigECDSA = []byte{48, 129, 135, 2, 1, 0, 48, 19, 6, 7, 42, 134, 72, 206, 61, 2, 1, 6, 8, 42, 134, 72, 206, 61, 3, 1, 7, 4, 109, 48, 107, 2, 1, 1, 4, 32, 225, 45, 16, 217, 198, 48, 46, 37, 59, 165, 201, 242, 244, 143, 253, 127, 88, 84, 100, 25, 17, 39, 23, 128, 105, 241, 16, 227, 43, 47, 141, 187, 161, 68, 3, 66, 0, 4, 155, 65, 151, 191, 69, 30, 154, 72, 179, 179, 4, 5, 106, 97, 81, 20, 114, 8, 188, 137, 58, 81, 123, 3, 26, 111, 172, 26, 107, 212, 60, 52, 154, 177, 135, 254, 199, 8, 246, 198, 147, 23, 228, 46, 70, 145, 133, 222, 82, 222, 243, 113, 9, 10, 149, 59, 21, 144, 195, 215, 174, 175, 82, 51} + testKeyBytesEncECDSA = []byte{48, 129, 135, 2, 1, 0, 48, 19, 6, 7, 42, 134, 72, 206, 61, 2, 1, 6, 8, 42, 134, 72, 206, 61, 3, 1, 7, 4, 109, 48, 107, 2, 1, 1, 4, 32, 107, 125, 138, 26, 157, 158, 163, 251, 241, 207, 65, 183, 174, 68, 61, 135, 25, 188, 173, 245, 37, 8, 122, 233, 53, 113, 233, 221, 63, 91, 240, 100, 161, 68, 3, 66, 0, 4, 65, 254, 168, 68, 75, 55, 98, 71, 60, 193, 218, 125, 32, 224, 117, 19, 63, 145, 62, 117, 104, 107, 245, 157, 112, 192, 2, 44, 153, 73, 158, 193, 235, 150, 58, 174, 115, 25, 79, 68, 111, 142, 179, 168, 231, 172, 54, 214, 101, 18, 244, 173, 69, 22, 255, 235, 73, 227, 247, 206, 254, 183, 17, 177} + testKeySigECDSA jose.JSONWebKey + testKeyPublicSigECDSA jose.JSONWebKey + testKeyEncECDSA jose.JSONWebKey + testKeyPublicEncECDSA jose.JSONWebKey + + testCompactSerializedNestedJWEWithRSA = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJjdHkiOiJKV1QiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1yc2EtZW5jIiwidHlwIjoiSldUIn0.oWwJ7cVudU3U_EYu9vc4bXQqg7uH3xOvmjlRDKKFssc7oRDnM5IG3mvPpPnXm-jJhB0q4pGAHsVrCeRARUHTrkNWLgU4NtNhEhaeNVBoV38KxNyMXvfBdYoc91wPwN1TBQgSSBJ7FZPwDnUODAIIxh4NTyVJt6mEPlkDNrzViwv2zRhkfosdoiUJjbee2G5tQOV3Jj5o9gKrQOwZ9fJry-zKgOSeb0VR9s9vdfL8STKnwUnQ2HIYpsEG19IpXVut17Y8cIg43n65NrAHNxv2wRj-2vWij-bXM-YugSbNB7LH1n5H2wNW20nBKlYX7oLGWJHPuVxYftFIaPsm7m3sBA.6vLOfe12rdKCd0RW.-zYmgfbJNag4h0x3NQhzZCVAgczLDFYnW7fWKE51-ZyiUllRPiNuTQg-EQvwwNmFRNvwWpWFEpuhH9ceYYqfwrY0ZAqT7m3nDM24xvx46B2jZJvY3fCxB6XA72Afw4yxnRio_KMvS_vrTGsp4sLlHujolAAc3j-8y53uJ1y1mvfsKPT8i8YLyBtOT0wm1hVZAdyue-TA7iXjXRrbe2IdbW_FagXXyBW8JhOhFe9O5-T1Ts_SbPUQhwSYb5pIfUjfDLkVeEcOB9QNQw1Ai_bN2bgOh23sD4aS6G1alWEN0zHAzp0qgS7NXbyqBQefaUlcquSdM9JB9arLVGcy2IPFuC2zy7oppfcpCqmjPhXXJ0-WYA92FSAsREJCh6KsVu445KrfEQalyfmMd6qE2agDxqnDgdrIxlMzzxwCc4FvlwZ-3c3SDY5sZK_-auynoHx1adeguU1wPWY3Wy-2tEr5qsEK_P4M2h-AvcKRvMrqp6JLu3tQeSdhBGlLoEehABWp8fQqAASHlKEOLRe8znx2qSUMmdDfry8OzDYhYzcUf14YhneoEBv-HOLKykqaPWSlE-7Mkc6BY5gnoyHqznhP5cK6q-jCIUJVBRbKeTrWT2SiwZUcv89nDJL-YAMXG3GL_POPYr_TvTueRfyQL4xm1TRmZQ.fLEVOYxpohATkcNIolk1UA" + testCompactSerializedNestedJWEWithECDSA = "eyJhbGciOiJFQ0RILUVTK0ExMjhLVyIsImN0eSI6IkpXVCIsImVuYyI6IkExMjhHQ00iLCJlcGsiOnsia3R5IjoiRUMiLCJjcnYiOiJQLTI1NiIsIngiOiJRcmdXc2wwN3lFclFjaXFMM1dKZE5FU015S01SR3d2bWRnYmlYNmRYSmpjIiwieSI6Ikl5MWRpellaZFVSTnpKS1FaVlZPVnkwSVVKWjhLaWVyN19LSllJS1hzaTAifSwia2lkIjoidGVzdC1lY2RzYS1lbmMiLCJ0eXAiOiJKV1QifQ.IDkNo5DGa6VQrL8ReJrBVixYN0S_VYYE.soDQKUuDrakfrQer.aaQWcWYoUF8ISUQ_EvkRCa75GeLFMsSK3imQjc3T0OalHsCIEXYCoV_vmjDPTd4svswMQtTiZxeajevnsBl_dtaEmykjqshxHww-07r36RhWKlix3gSTJTKUvGAhFDl24HcLnWUkZZjh0Vw0G9hLidax1OGoc43Rh08aHJ3swbj6yOA-KH-0SIBBmeK1Mfb0-1I4LRdAkCeyy4P6y8z2TvEqAtfCFDfAs5O8Zm9yVA6sxFAB5l6dK4WdotOh4F8lu-vE6MfD67Qi8xiW92ccYX7fBliNyypkDaN3B1k25N374qGXYl0_z0cX2T5ba_doVYgFNDFp.bmjYOk_ZXNTN2yZQjpFLjw" +) + +func init() { + var ( + key any + err error + ) + + if key, err = x509.ParsePKCS8PrivateKey(testKeyBytesSigRSA); err != nil { + panic(err) + } + + switch k := key.(type) { + case *rsa.PrivateKey: + testKeySigRSA = jose.JSONWebKey{ + Key: k, + KeyID: "test-rsa-sig", + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.RS256), + } + testKeyPublicSigRSA = jose.JSONWebKey{ + Key: k.Public(), + KeyID: "test-rsa-sig", + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.RS256), + } + default: + panic("unsupported private key") + } + + if key, err = x509.ParsePKCS8PrivateKey(testKeyBytesEncRSA); err != nil { + panic(err) + } + + switch k := key.(type) { + case *rsa.PrivateKey: + testKeyEncRSA = jose.JSONWebKey{ + Key: k, + KeyID: "test-rsa-enc", + Use: consts.JSONWebTokenUseEncryption, + Algorithm: string(jose.RSA_OAEP_256), + } + testKeyPublicEncRSA = jose.JSONWebKey{ + Key: k.Public(), + KeyID: "test-rsa-enc", + Use: consts.JSONWebTokenUseEncryption, + Algorithm: string(jose.RSA_OAEP_256), + } + default: + panic("unsupported private key") + } + + if key, err = x509.ParsePKCS8PrivateKey(testKeyBytesSigECDSA); err != nil { + panic(err) + } + + switch k := key.(type) { + case *ecdsa.PrivateKey: + testKeySigECDSA = jose.JSONWebKey{ + Key: k, + KeyID: "test-ecdsa-sig", + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.ES256), + } + testKeyPublicSigECDSA = jose.JSONWebKey{ + Key: k.Public(), + KeyID: "test-ecdsa-sig", + Use: consts.JSONWebTokenUseSignature, + Algorithm: string(jose.ES256), + } + default: + panic("unsupported private key") + } + + if key, err = x509.ParsePKCS8PrivateKey(testKeyBytesEncECDSA); err != nil { + panic(err) + } + + switch k := key.(type) { + case *ecdsa.PrivateKey: + testKeyEncECDSA = jose.JSONWebKey{ + Key: k, + KeyID: "test-ecdsa-enc", + Use: consts.JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A128KW), + } + testKeyPublicEncECDSA = jose.JSONWebKey{ + Key: k.Public(), + KeyID: "test-ecdsa-enc", + Use: consts.JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A128KW), + } + default: + panic("unsupported private key") + } +} + +func TestIniit(t *testing.T) { + claims := MapClaims{ + "iss": "example.com", + "sub": "john", + "iat": time.Now().UTC().Unix(), + "exp": time.Now().Add(time.Hour * 24 * 365 * 40).UTC().Unix(), + } + + out, _, err := encodeNestedCompactEncrypted(context.TODO(), claims, &Headers{}, &Headers{}, &testKeySigECDSA, &testKeyPublicEncECDSA, jose.A128GCM) + + fmt.Println(err) + fmt.Println(out) +} diff --git a/token/jwt/token.go b/token/jwt/token.go index b887d826..dcefdb20 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -364,6 +364,28 @@ func (t *Token) Valid(opts ...HeaderValidationOption) (err error) { vErr.Errors |= ValidationErrorSignatureInvalid } + if t.HeaderJWE != nil && (t.KeyAlgorithm != "" || t.ContentEncryption != "") { + var ( + cty, typ, ttyp any + ok bool + ) + + if typ, ok = t.HeaderJWE[consts.JSONWebTokenHeaderType]; !ok || typ != consts.JSONWebTokenTypeJWT { + vErr.Inner = errors.New("token was encrypted with invalid typ") + vErr.Errors |= ValidationErrorHeaderEncryptionTypeInvalid + } + + if ttyp, ok = t.Header[consts.JSONWebTokenHeaderType]; !ok { + vErr.Inner = errors.New("token was signed with invalid typ") + vErr.Errors |= ValidationErrorHeaderTypeInvalid + } + + if cty, ok = t.HeaderJWE[consts.JSONWebTokenHeaderContentType]; !ok || cty != ttyp { + vErr.Inner = errors.New("token was encrypted with invalid cty or signed with an invalid typ") + vErr.Errors |= ValidationErrorHeaderContentTypeInvalid + } + } + if len(vopts.types) != 0 { if !validateTokenType(vopts.types, t.Header) { vErr.Inner = errors.New("token was signed with an invalid typ") diff --git a/token/jwt/util.go b/token/jwt/util.go index da7c7714..05813562 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -319,6 +319,14 @@ func encodeCompactSigned(ctx context.Context, claims MapClaims, headers Mapper, func encodeNestedCompactEncrypted(ctx context.Context, claims MapClaims, headers, headersJWE Mapper, keySig, keyEnc *jose.JSONWebKey, enc jose.ContentEncryption) (tokenString string, signature string, err error) { token := New() + if headers == nil { + headers = &Headers{} + } + + if headersJWE == nil { + headersJWE = &Headers{} + } + token.SetJWS(headers, claims, keySig.KeyID, jose.SignatureAlgorithm(keySig.Algorithm)) token.SetJWE(headersJWE, keyEnc.KeyID, jose.KeyAlgorithm(keyEnc.Algorithm), enc, jose.NONE) diff --git a/token/jwt/validation_error.go b/token/jwt/validation_error.go index 085102fb..712a22df 100644 --- a/token/jwt/validation_error.go +++ b/token/jwt/validation_error.go @@ -12,18 +12,20 @@ const ( ValidationErrorSignatureInvalid // Signature validation failed. ValidationErrorHeaderKeyIDInvalid // Header KID invalid error. ValidationErrorHeaderAlgorithmInvalid // Header ALG invalid error. - ValidationErrorHeaderTypeInvalid // Header TYP invalid error + ValidationErrorHeaderTypeInvalid // Header TYP invalid error. + ValidationErrorHeaderEncryptionTypeInvalid // Header TYP invalid error (JWE). + ValidationErrorHeaderContentTypeInvalid // Header TYP invalid error (JWE). ValidationErrorHeaderEncryptionKeyIDInvalid // Header KID invalid error (JWE). ValidationErrorHeaderKeyAlgorithmInvalid // Header ALG invalid error (JWE). ValidationErrorHeaderContentEncryptionInvalid // Header ENC invalid error (JWE). - ValidationErrorId // Claim JTI validation failed - ValidationErrorAudience // Claim AUD validation failed - ValidationErrorExpired // Claim EXP validation failed - ValidationErrorIssuedAt // Claim IAT validation failed - ValidationErrorNotValidYet // Claim NBF validation failed - ValidationErrorIssuer // Claim ISS validation failed - ValidationErrorSubject // Claim SUB validation failed - ValidationErrorClaimsInvalid // Generic claims validation error + ValidationErrorId // Claim JTI validation failed. + ValidationErrorAudience // Claim AUD validation failed. + ValidationErrorExpired // Claim EXP validation failed. + ValidationErrorIssuedAt // Claim IAT validation failed. + ValidationErrorNotValidYet // Claim NBF validation failed. + ValidationErrorIssuer // Claim ISS validation failed. + ValidationErrorSubject // Claim SUB validation failed. + ValidationErrorClaimsInvalid // Generic claims validation error. ) // The ValidationError is an error implementation from Parse if token is not valid. From 089cca1a5a346a04188651398089f15fa4972626 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 27 Sep 2024 16:43:06 +1000 Subject: [PATCH 18/33] client auth tests --- generate-mocks.sh | 3 +-- generate.go | 3 +-- testing/mock/client.go | 7 ++++--- testing/mock/pkce_storage.go | 6 +++--- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/generate-mocks.sh b/generate-mocks.sh index 513824d9..5c844046 100755 --- a/generate-mocks.sh +++ b/generate-mocks.sh @@ -2,7 +2,6 @@ ${MOCKGEN:-mockgen} -package mock -destination testing/mock/rw.go net/http ResponseWriter -${MOCKGEN:-mockgen} -package mock -destination testing/mock/hash.go authelia.com/provider/oauth2 Hasher ${MOCKGEN:-mockgen} -package mock -destination testing/mock/introspector.go authelia.com/provider/oauth2 TokenIntrospector ${MOCKGEN:-mockgen} -package mock -destination testing/mock/client.go authelia.com/provider/oauth2 Client ${MOCKGEN:-mockgen} -package mock -destination testing/mock/client_secret.go authelia.com/provider/oauth2 ClientSecret @@ -18,7 +17,7 @@ ${MOCKGEN:-mockgen} -package mock -destination testing/mock/transactional.go aut ${MOCKGEN:-mockgen} -package mock -destination testing/mock/oauth2_storage.go authelia.com/provider/oauth2/handler/oauth2 CoreStorage ${MOCKGEN:-mockgen} -package mock -destination testing/mock/oauth2_device_auth_storage.go -mock_names Storage=MockRFC8628Storage authelia.com/provider/oauth2/handler/rfc8628 Storage ${MOCKGEN:-mockgen} -package mock -destination testing/mock/openid_id_token_storage.go authelia.com/provider/oauth2/handler/openid OpenIDConnectRequestStorage -${MOCKGEN:-mockgen} -package mock -destination testing/mock/pkce_storage.go authelia.com/provider/oauth2/handler/pkce PKCERequestStorage +${MOCKGEN:-mockgen} -package mock -destination testing/mock/pkce_storage.go -mock_names Storage=MockPKCERequestStorage authelia.com/provider/oauth2/handler/pkce Storage ${MOCKGEN:-mockgen} -package mock -destination testing/mock/authorize_code_storage.go authelia.com/provider/oauth2/handler/oauth2 AuthorizeCodeStorage ${MOCKGEN:-mockgen} -package mock -destination testing/mock/oauth2_auth_jwt_storage.go authelia.com/provider/oauth2/handler/rfc7523 RFC7523KeyStorage ${MOCKGEN:-mockgen} -package mock -destination testing/mock/access_token_storage.go authelia.com/provider/oauth2/handler/oauth2 AccessTokenStorage diff --git a/generate.go b/generate.go index cf07ce78..5d33df43 100644 --- a/generate.go +++ b/generate.go @@ -3,7 +3,6 @@ package oauth2 -//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/hash.go authelia.com/provider/oauth2 Hasher //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/storage.go authelia.com/provider/oauth2 Storage //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/transactional.go authelia.com/provider/oauth2/storage Transactional //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/oauth2_storage.go authelia.com/provider/oauth2/handler/oauth2 CoreStorage @@ -20,7 +19,7 @@ package oauth2 //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/refresh_token_strategy.go authelia.com/provider/oauth2/handler/oauth2 ReyfreshTokenStrategy //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/authorize_code_strategy.go authelia.com/provider/oauth2/handler/oauth2 AuthorizeCodeStrategy //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/id_token_strategy.go authelia.com/provider/oauth2/handler/openid OpenIDConnectTokenStrategy -//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/pkce_storage_strategy.go authelia.com/provider/oauth2/handler/pkce PKCERequestStorage +//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/pkce_storage_strategy.go -mock_names Storage=MockPKCERequestStorage authelia.com/provider/oauth2/handler/pkce Storage //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/authorize_handler.go authelia.com/provider/oauth2 AuthorizeEndpointHandler //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/revoke_handler.go authelia.com/provider/oauth2 RevocationHandler //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/token_handler.go authelia.com/provider/oauth2 TokenEndpointHandler diff --git a/testing/mock/client.go b/testing/mock/client.go index 1532fc6c..7c88ba6c 100644 --- a/testing/mock/client.go +++ b/testing/mock/client.go @@ -68,12 +68,13 @@ func (mr *MockClientMockRecorder) GetClientSecret() *gomock.Call { } // GetClientSecretPlainText mocks base method. -func (m *MockClient) GetClientSecretPlainText() ([]byte, error) { +func (m *MockClient) GetClientSecretPlainText() ([]byte, bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetClientSecretPlainText") ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // GetClientSecretPlainText indicates an expected call of GetClientSecretPlainText. diff --git a/testing/mock/pkce_storage.go b/testing/mock/pkce_storage.go index f3a9014d..baf4fb4f 100644 --- a/testing/mock/pkce_storage.go +++ b/testing/mock/pkce_storage.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: authelia.com/provider/oauth2/handler/pkce (interfaces: PKCERequestStorage) +// Source: authelia.com/provider/oauth2/handler/pkce (interfaces: Storage) // // Generated by this command: // -// mockgen -package mock -destination testing/mock/pkce_storage.go authelia.com/provider/oauth2/handler/pkce PKCERequestStorage +// mockgen -package mock -destination testing/mock/pkce_storage.go -mock_names Storage=MockPKCERequestStorage authelia.com/provider/oauth2/handler/pkce Storage // // Package mock is a generated GoMock package. @@ -17,7 +17,7 @@ import ( gomock "go.uber.org/mock/gomock" ) -// MockPKCERequestStorage is a mock of PKCERequestStorage interface. +// MockPKCERequestStorage is a mock of Storage interface. type MockPKCERequestStorage struct { ctrl *gomock.Controller recorder *MockPKCERequestStorageMockRecorder From 5a03150ecb6a7f46ea497a07e3c0f0fddc1057a1 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 27 Sep 2024 17:24:11 +1000 Subject: [PATCH 19/33] client auth tests --- authorize_request_handler.go | 61 ------------------------------------ 1 file changed, 61 deletions(-) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index fb24c97f..36552c29 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -206,67 +206,6 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co return errorsx.WithStack(fmtRequestObjectDecodeError(token, client, issuer, openid, err)) } - /* - if !algNone { - issuer := f.Config.GetIDTokenIssuer(ctx) - - if len(issuer) == 0 { - return errorsx.WithStack(ErrServerError.WithHintf("%s request could not be processed due to an authorization server configuration issue.", hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object that was signed but the issuer for this authorization server is not known.", request.GetClient().GetID())) - } - - claimsOpts = append(claimsOpts, jwt.(issuer)) - if err = claims.Valid(jwt.ValidateIssuer(client.GetID()), jwt.ValidateAudienceAny(issuer)); err != nil { - - } - - if v, ok = claims[consts.ClaimIssuer]; !ok { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectSignedAbsentClaim, request.GetClient().GetID(), consts.ClaimIssuer)) - } - - clientID := request.GetClient().GetID() - - if value, ok = v.(string); !ok { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectValueTypeNotString, request.GetClient().GetID(), consts.ClaimIssuer, v, clientID, v)) - } - - if value != clientID { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectValueMismatch, clientID, consts.ClaimIssuer, value, clientID)) - } - - if v, ok = claims[consts.ClaimAudience]; !ok { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectSignedAbsentClaim, request.GetClient().GetID(), consts.ClaimAudience)) - } - - var valid bool - - switch t := v.(type) { - case string: - valid = t == issuer - case []string: - for _, value = range t { - if value == issuer { - valid = true - - break - } - } - case []any: - for _, x := range t { - if value, ok = x.(string); ok && value == issuer { - valid = true - - break - } - } - } - - if !valid { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' included a request object with a 'aud' claim with the values '%s' which is required match the issuer '%s'.", request.GetClient().GetID(), value, issuer)) - } - } - - */ - claimScope := RemoveEmpty(strings.Split(request.Form.Get(consts.FormParameterScope), " ")) for _, s := range scope { if !stringslice.Has(claimScope, s) { From 18b4e1b6f2567f3dfb7daa5d227033a64857f594 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 27 Sep 2024 18:37:16 +1000 Subject: [PATCH 20/33] client auth tests --- go.mod | 1 - go.sum | 2 - handler/openid/flow_hybrid_test.go | 24 ++------ internal/test_helpers.go | 11 ++-- introspection_request_handler_test.go | 4 -- token/jwt/claims_id_token.go | 85 +++++++++++++++++++++++++++ token/jwt/claims_jarm.go | 4 +- token/jwt/claims_jwt.go | 43 +++++++++++--- token/jwt/claims_map.go | 24 +------- token/jwt/jwt_strategy_opts.go | 2 +- token/jwt/util.go | 13 ++++ 11 files changed, 148 insertions(+), 65 deletions(-) diff --git a/go.mod b/go.mod index ff1d6c06..9f51070d 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ toolchain go1.23.1 require ( github.com/dgraph-io/ristretto v0.1.1 github.com/go-jose/go-jose/v4 v4.0.4 - github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 github.com/hashicorp/go-retryablehttp v0.7.7 diff --git a/go.sum b/go.sum index 287c8796..9927a4e4 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,6 @@ github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/go-jose/go-jose/v4 v4.0.4 h1:VsjPI33J0SB9vQM6PLmNjoHqMQNGPiZ0rHL7Ni7Q6/E= github.com/go-jose/go-jose/v4 v4.0.4/go.mod h1:NKb5HO1EZccyMpiZNbdUw/14tiXNyUJh188dfnMCAfc= -github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= -github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY= github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= diff --git a/handler/openid/flow_hybrid_test.go b/handler/openid/flow_hybrid_test.go index 9364c9b4..4b7a8d76 100644 --- a/handler/openid/flow_hybrid_test.go +++ b/handler/openid/flow_hybrid_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - xjwt "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -190,11 +189,9 @@ func TestHybrid_HandleAuthorizeEndpointRequest(t *testing.T) { assert.NotEmpty(t, idToken) assert.True(t, request.GetSession().GetExpiresAt(oauth2.IDToken).IsZero()) - parser := xjwt.NewParser() + claims := &jwt.IDTokenClaims{} - claims := &IDTokenClaims{} - - _, _, err := parser.ParseUnverified(idToken, claims) + _, err := jwt.UnsafeParseSignedAny(idToken, claims) require.NoError(t, err) assert.Equal(t, "MvmJNOT-fq6rnnnrUTC_2A", claims.StateHash) @@ -337,14 +334,12 @@ func TestHybrid_HandleAuthorizeEndpointRequest(t *testing.T) { assert.NotEmpty(t, idToken) assert.True(t, request.GetSession().GetExpiresAt(oauth2.IDToken).IsZero()) - parser := xjwt.NewParser() - - claims := &IDTokenClaims{} - - _, _, err := parser.ParseUnverified(idToken, claims) + claims := &jwt.IDTokenClaims{} + _, err := jwt.UnsafeParseSignedAny(idToken, claims) require.NoError(t, err) - internal.RequireEqualTime(t, time.Now().Add(*internal.TestLifespans.ImplicitGrantIDTokenLifespan), claims.ExpiresAt.Time, time.Minute) + + internal.RequireEqualTime(t, time.Now().Add(*internal.TestLifespans.ImplicitGrantIDTokenLifespan), claims.ExpiresAt, time.Minute) assert.NotEmpty(t, claims.CodeHash) assert.Empty(t, claims.StateHash) @@ -415,13 +410,6 @@ func TestHybrid_HandleAuthorizeEndpointRequest(t *testing.T) { } } -type IDTokenClaims struct { - StateHash string `json:"s_hash"` - CodeHash string `json:"c_hash"` - - xjwt.RegisteredClaims -} - var hmacStrategy = &hoauth2.HMACCoreStrategy{ Enigma: &hmac.HMACStrategy{ Config: &oauth2.Config{ diff --git a/internal/test_helpers.go b/internal/test_helpers.go index 3638a1d6..b2ce1f64 100644 --- a/internal/test_helpers.go +++ b/internal/test_helpers.go @@ -12,13 +12,13 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" "golang.org/x/net/html" xoauth2 "golang.org/x/oauth2" "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" ) func ptr(d time.Duration) *time.Duration { @@ -61,18 +61,17 @@ func RequireEqualTime(t *testing.T, expected time.Time, actual time.Time, precis } func ExtractJwtExpClaim(t *testing.T, token string) *time.Time { - parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + claims := &jwt.IDTokenClaims{} - claims := &jwt.RegisteredClaims{} + _, err := jwt.UnsafeParseSignedAny(token, claims) - _, _, err := parser.ParseUnverified(token, claims) require.NoError(t, err) - if claims.ExpiresAt == nil { + if claims.ExpiresAt.IsZero() { return nil } - return &claims.ExpiresAt.Time + return &claims.ExpiresAt } //nolint:gocyclo diff --git a/introspection_request_handler_test.go b/introspection_request_handler_test.go index 3dcf0b5b..05c16f95 100644 --- a/introspection_request_handler_test.go +++ b/introspection_request_handler_test.go @@ -342,10 +342,6 @@ func TestIntrospectionResponseToMap(t *testing.T) { consts.ClaimAudience: []string{"https://example.com", "aclient"}, consts.ClaimIssuedAt: int64(100000), consts.ClaimClientIdentifier: "aclient", - //"aclaim": 1, - //consts.ClaimSubject: "asubj", - //consts.ClaimExpirationTime: int64(1000000), - //consts.ClaimUsername: "auser", }, }, } diff --git a/token/jwt/claims_id_token.go b/token/jwt/claims_id_token.go index 862b9305..255508bf 100644 --- a/token/jwt/claims_id_token.go +++ b/token/jwt/claims_id_token.go @@ -4,11 +4,15 @@ package jwt import ( + "bytes" + "fmt" "time" + jjson "github.com/go-jose/go-jose/v4/json" "github.com/google/uuid" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/x/errorsx" ) // IDTokenClaims represent the claims used in open id connect requests @@ -30,6 +34,87 @@ type IDTokenClaims struct { Extra map[string]any `json:"ext"` } +func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { + claims := MapClaims{} + + decoder := jjson.NewDecoder(bytes.NewReader(data)) + decoder.SetNumberType(jjson.UnmarshalIntOrFloat) + + if err := decoder.Decode(&claims); err != nil { + return errorsx.WithStack(err) + } + + var ok bool + + for claim, value := range claims { + ok = false + + switch claim { + case consts.ClaimJWTID: + c.JTI, ok = value.(string) + case consts.ClaimIssuer: + c.Issuer, ok = value.(string) + case consts.ClaimSubject: + c.Subject, ok = value.(string) + case consts.ClaimAudience: + switch aud := value.(type) { + case string: + ok = true + + c.Audience = []string{aud} + case []string: + ok = true + + c.Audience = aud + case []any: + ok = true + + loop: + for _, av := range aud { + switch a := av.(type) { + case string: + c.Audience = append(c.Audience, a) + default: + ok = false + + break loop + } + } + } + case consts.ClaimNonce: + c.Nonce, ok = value.(string) + case consts.ClaimExpirationTime: + c.ExpiresAt, ok = toTime(value, c.ExpiresAt) + case consts.ClaimIssuedAt: + c.IssuedAt, ok = toTime(value, c.IssuedAt) + case consts.ClaimRequestedAt: + c.RequestedAt, ok = toTime(value, c.RequestedAt) + case consts.ClaimAuthenticationTime: + c.AuthTime, ok = toTime(value, c.AuthTime) + case consts.ClaimCodeHash: + c.CodeHash, ok = value.(string) + case consts.ClaimStateHash: + c.StateHash, ok = value.(string) + case consts.ClaimAuthenticationContextClassReference: + c.AuthenticationContextClassReference, ok = value.(string) + default: + if c.Extra == nil { + c.Extra = make(map[string]any) + } + + c.Extra[claim] = value + + continue + } + + if !ok { + return fmt.Errorf("claim %s with value %v could not be decoded", claim, value) + } + } + + return nil +} + // ToMap will transform the headers to a map structure func (c *IDTokenClaims) ToMap() map[string]any { var ret = Copy(c.Extra) diff --git a/token/jwt/claims_jarm.go b/token/jwt/claims_jarm.go index d43c5a98..8557c19d 100644 --- a/token/jwt/claims_jarm.go +++ b/token/jwt/claims_jarm.go @@ -73,9 +73,9 @@ func (c *JARMClaims) FromMap(m map[string]any) { c.Audience = aud } case consts.ClaimIssuedAt: - c.IssuedAt = toTime(v, c.IssuedAt) + c.IssuedAt, _ = toTime(v, c.IssuedAt) case consts.ClaimExpirationTime: - c.ExpiresAt = toTime(v, c.ExpiresAt) + c.ExpiresAt, _ = toTime(v, c.ExpiresAt) default: c.Extra[k] = v } diff --git a/token/jwt/claims_jwt.go b/token/jwt/claims_jwt.go index 0ae784de..1e1d356e 100644 --- a/token/jwt/claims_jwt.go +++ b/token/jwt/claims_jwt.go @@ -4,6 +4,7 @@ package jwt import ( + "encoding/json" "strings" "time" @@ -183,11 +184,11 @@ func (c *JWTClaims) FromMap(m map[string]any) { c.Audience = aud } case consts.ClaimIssuedAt: - c.IssuedAt = toTime(v, c.IssuedAt) + c.IssuedAt, _ = toTime(v, c.IssuedAt) case consts.ClaimNotBefore: - c.NotBefore = toTime(v, c.NotBefore) + c.NotBefore, _ = toTime(v, c.NotBefore) case consts.ClaimExpirationTime: - c.ExpiresAt = toTime(v, c.ExpiresAt) + c.ExpiresAt, _ = toTime(v, c.ExpiresAt) case consts.ClaimScopeNonStandard: switch s := v.(type) { case []string: @@ -225,15 +226,41 @@ func (c *JWTClaims) FromMap(m map[string]any) { } } -func toTime(v any, def time.Time) (t time.Time) { +func toTime(v any, def time.Time) (t time.Time, ok bool) { t = def - switch a := v.(type) { + + var value int64 + + if value, ok = toInt64(v); ok { + t = time.Unix(value, 0).UTC() + } + + return +} + +func toInt64(v any) (val int64, ok bool) { + var err error + + switch t := v.(type) { case float64: - t = time.Unix(int64(a), 0).UTC() + return int64(t), true case int64: - t = time.Unix(a, 0).UTC() + return t, true + case json.Number: + if val, err = t.Int64(); err == nil { + return val, true + } + + var valf float64 + + if valf, err = t.Float64(); err != nil { + return 0, false + } + + return int64(valf), true } - return + + return 0, false } // Add will add a key-value pair to the extra field diff --git a/token/jwt/claims_map.go b/token/jwt/claims_map.go index 3a067301..8f85404e 100644 --- a/token/jwt/claims_map.go +++ b/token/jwt/claims_map.go @@ -6,7 +6,6 @@ package jwt import ( "bytes" "crypto/subtle" - "encoding/json" "errors" "time" @@ -256,28 +255,7 @@ func (m MapClaims) UnmarshalJSON(b []byte) error { } func (m MapClaims) toInt64(claim string) (val int64, ok bool) { - var err error - - switch t := m[claim].(type) { - case float64: - return int64(t), true - case int64: - return t, true - case json.Number: - if val, err = t.Int64(); err == nil { - return val, true - } - - var valf float64 - - if valf, err = t.Float64(); err != nil { - return 0, false - } - - return int64(valf), true - } - - return 0, false + return toInt64(m[claim]) } type ClaimValidationOption func(opts *ClaimValidationOptions) diff --git a/token/jwt/jwt_strategy_opts.go b/token/jwt/jwt_strategy_opts.go index 1828fcc0..21520c35 100644 --- a/token/jwt/jwt_strategy_opts.go +++ b/token/jwt/jwt_strategy_opts.go @@ -135,7 +135,7 @@ func WithJWTProfileAccessTokenClient(client any) StrategyOpt { } } -func WithNewStatelessJWTProfileIntrospectionClient(client any) StrategyOpt { +func WithStatelessJWTProfileIntrospectionClient(client any) StrategyOpt { return func(opts *StrategyOpts) (err error) { switch c := client.(type) { case IntrospectionClient: diff --git a/token/jwt/util.go b/token/jwt/util.go index 05813562..84758329 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -4,6 +4,7 @@ import ( "context" "crypto" "fmt" + jjwt "github.com/go-jose/go-jose/v4/jwt" "regexp" "strings" @@ -374,3 +375,15 @@ func getPublicJWK(jwk *jose.JSONWebKey) jose.JSONWebKey { return jwk.Public() } + +func UnsafeParseSignedAny(tokenString string, dest any) (token *jjwt.JSONWebToken, err error) { + if token, err = jjwt.ParseSigned(tokenString, SignatureAlgorithmsNone); err != nil { + return nil, err + } + + if err = token.UnsafeClaimsWithoutVerification(dest); err != nil { + return nil, err + } + + return token, nil +} From 2d8115b6f772ca914bcb605631f425ccc16553ac Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 27 Sep 2024 18:43:20 +1000 Subject: [PATCH 21/33] client auth tests --- token/jwt/util.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/token/jwt/util.go b/token/jwt/util.go index 84758329..42a1625d 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -4,11 +4,11 @@ import ( "context" "crypto" "fmt" - jjwt "github.com/go-jose/go-jose/v4/jwt" "regexp" "strings" "github.com/go-jose/go-jose/v4" + jjwt "github.com/go-jose/go-jose/v4/jwt" "github.com/pkg/errors" "authelia.com/provider/oauth2/internal/consts" From 33bb1f56b527a08a924c892dd60479bf0776eb87 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 27 Sep 2024 21:48:11 +1000 Subject: [PATCH 22/33] client auth tests --- token/jwt/claims_jwt.go | 4 + token/jwt/claims_map.go | 53 +- token/jwt/claims_map_test.go | 1019 ++++++++++++++++++++++++++++++++-- token/jwt/claims_test.go | 41 ++ 4 files changed, 1039 insertions(+), 78 deletions(-) diff --git a/token/jwt/claims_jwt.go b/token/jwt/claims_jwt.go index 1e1d356e..7efdbb5f 100644 --- a/token/jwt/claims_jwt.go +++ b/token/jwt/claims_jwt.go @@ -246,6 +246,10 @@ func toInt64(v any) (val int64, ok bool) { return int64(t), true case int64: return t, true + case int32: + return int64(t), true + case int: + return int64(t), true case json.Number: if val, err = t.Int64(); err == nil { return val, true diff --git a/token/jwt/claims_map.go b/token/jwt/claims_map.go index 8f85404e..844ee7de 100644 --- a/token/jwt/claims_map.go +++ b/token/jwt/claims_map.go @@ -76,7 +76,13 @@ func (m MapClaims) VerifySubject(cmp string, required bool) (ok bool) { // GetAudience returns the aud claim. func (m MapClaims) GetAudience() (aud []string, ok bool) { - return StringSliceFromMap(m[consts.ClaimAudience]) + var v any + + if v, ok = m[consts.ClaimAudience]; !ok { + return nil, false + } + + return StringSliceFromMap(v) } // VerifyAudience compares the aud claim against cmp. @@ -131,7 +137,7 @@ func (m MapClaims) VerifyExpiresAt(cmp int64, required bool) (ok bool) { return !required } - return verifyExp(exp, cmp, required) + return verifyInt64Future(exp, cmp, required) } // GetIssuedAt returns the iat claim. @@ -238,16 +244,13 @@ func (m MapClaims) Valid(opts ...ClaimValidationOption) (err error) { return vErr } -func (m MapClaims) UnmarshalJSON(b []byte) error { - // This custom unmarshal allows to configure the - // go-jose decoding settings since there is no other way - // see https://github.com/square/go-jose/issues/353. - // If issue is closed with a better solution - // this custom Unmarshal method can be removed - d := jjson.NewDecoder(bytes.NewReader(b)) +func (m MapClaims) UnmarshalJSON(data []byte) error { + decoder := jjson.NewDecoder(bytes.NewReader(data)) + decoder.SetNumberType(jjson.UnmarshalIntOrFloat) + mp := map[string]any(m) - d.SetNumberType(jjson.UnmarshalIntOrFloat) - if err := d.Decode(&mp); err != nil { + + if err := decoder.Decode(&mp); err != nil { return errorsx.WithStack(err) } @@ -255,7 +258,13 @@ func (m MapClaims) UnmarshalJSON(b []byte) error { } func (m MapClaims) toInt64(claim string) (val int64, ok bool) { - return toInt64(m[claim]) + var v any + + if v, ok = m[claim]; !ok { + return 0, false + } + + return toInt64(v) } type ClaimValidationOption func(opts *ClaimValidationOptions) @@ -368,26 +377,28 @@ outer: return true } -func verifyExp(exp int64, now int64, required bool) bool { - if exp == 0 { +// verifyInt64Future ensures the given value is in the future. +func verifyInt64Future(value, now int64, required bool) bool { + if value == 0 { return !required } - return now <= exp + return now <= value } -func verifyInt64Past(iat int64, now int64, required bool) bool { - if iat == 0 { +// verifyInt64Past ensures the given value is in the past or the current value. +func verifyInt64Past(value, now int64, required bool) bool { + if value == 0 { return !required } - return now >= iat + return now >= value } -func verifyMapString(iss string, cmp string, required bool) bool { - if iss == "" { +func verifyMapString(value, cmp string, required bool) bool { + if value == "" { return !required } - return subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) == 1 + return subtle.ConstantTimeCompare([]byte(value), []byte(cmp)) == 1 } diff --git a/token/jwt/claims_map_test.go b/token/jwt/claims_map_test.go index 9613e461..13f1c0dc 100644 --- a/token/jwt/claims_map_test.go +++ b/token/jwt/claims_map_test.go @@ -4,102 +4,1007 @@ package jwt import ( + "errors" "testing" + "time" + + "github.com/stretchr/testify/assert" "authelia.com/provider/oauth2/internal/consts" ) -// Test taken from taken from [here](https://raw.githubusercontent.com/form3tech-oss/jwt-go/master/map_claims_test.go). -func Test_mapClaims_list_aud(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: []string{"foo"}, +func TestMapClaims_VerifyAudience(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp string + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimAudience: []string{"foo"}, + }, + "foo", + true, + true, + }, + { + "ShouldPassMultiple", + MapClaims{ + consts.ClaimAudience: []string{"foo", "bar"}, + }, + "foo", + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + "foo", + true, + false, + }, + { + "ShouldFailNoMatch", + MapClaims{ + consts.ClaimAudience: []string{"bar"}, + }, + "foo", + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + "foo", + false, + true, + }, + { + "ShouldPassTypeAny", + MapClaims{ + consts.ClaimAudience: []any{"foo"}, + }, + "foo", + true, + true, + }, + { + "ShouldPassTypeString", + MapClaims{ + consts.ClaimAudience: "foo", + }, + "foo", + true, + true, + }, + { + "ShouldFailTypeString", + MapClaims{ + consts.ClaimAudience: "bar", + }, + "foo", + true, + false, + }, + { + "ShouldFailTypeNil", + MapClaims{ + consts.ClaimAudience: nil, + }, + "foo", + true, + false, + }, + { + "ShouldFailTypeSliceAnyInt", + MapClaims{ + consts.ClaimAudience: []any{1, 2, 3}, + }, + "foo", + true, + false, + }, + { + "ShouldFailTypeInt", + MapClaims{ + consts.ClaimAudience: 1, + }, + "foo", + true, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyAudience(tc.cmp, tc.required)) + }) + } +} + +func TestMapClaims_VerifyAudienceAll(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp []string + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimAudience: []string{"foo"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldFailMultipleAny", + MapClaims{ + consts.ClaimAudience: []string{"foo"}, + }, + []string{"foo", "bar"}, + true, + false, + }, + { + "ShouldPassMultiple", + MapClaims{ + consts.ClaimAudience: []string{"foo", "bar"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldPassMultipleAll", + MapClaims{ + consts.ClaimAudience: []string{"foo", "bar"}, + }, + []string{"foo", "bar"}, + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailNoMatch", + MapClaims{ + consts.ClaimAudience: []string{"bar"}, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + []string{"foo"}, + false, + true, + }, + { + "ShouldPassTypeAny", + MapClaims{ + consts.ClaimAudience: []any{"foo"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldPassTypeString", + MapClaims{ + consts.ClaimAudience: "foo", + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldFailTypeString", + MapClaims{ + consts.ClaimAudience: "bar", + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeNil", + MapClaims{ + consts.ClaimAudience: nil, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeSliceAnyInt", + MapClaims{ + consts.ClaimAudience: []any{1, 2, 3}, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeInt", + MapClaims{ + consts.ClaimAudience: 1, + }, + []string{"foo"}, + true, + false, + }, } - want := true - got := mapClaims.VerifyAudience("foo", true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyAudienceAll(tc.cmp, tc.required)) + }) } } -// This is a custom test to check that an empty -// list with require == false returns valid -func Test_mapClaims_empty_list_aud(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: []string{}, +func TestMapClaims_VerifyAudienceAny(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp []string + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimAudience: []string{"foo"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldFailMultipleAny", + MapClaims{ + consts.ClaimAudience: []string{"foo"}, + }, + []string{"foo", "bar"}, + true, + true, + }, + { + "ShouldPassMultiple", + MapClaims{ + consts.ClaimAudience: []string{"foo", "bar"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldPassMultipleAll", + MapClaims{ + consts.ClaimAudience: []string{"foo", "bar"}, + }, + []string{"foo", "bar"}, + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailNoMatch", + MapClaims{ + consts.ClaimAudience: []string{"bar"}, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + []string{"foo"}, + false, + true, + }, + { + "ShouldPassTypeAny", + MapClaims{ + consts.ClaimAudience: []any{"foo"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldPassTypeString", + MapClaims{ + consts.ClaimAudience: "foo", + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldFailTypeString", + MapClaims{ + consts.ClaimAudience: "bar", + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeNil", + MapClaims{ + consts.ClaimAudience: nil, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeSliceAnyInt", + MapClaims{ + consts.ClaimAudience: []any{1, 2, 3}, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeInt", + MapClaims{ + consts.ClaimAudience: 1, + }, + []string{"foo"}, + true, + false, + }, } - want := true - got := mapClaims.VerifyAudience("foo", false) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyAudienceAny(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_list_interface_aud(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: []any{"foo"}, +func TestMapClaims_VerifyIssuer(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp string + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimIssuer: "foo", + }, + "foo", + true, + true, + }, + { + "ShouldFailEmptyString", + MapClaims{ + consts.ClaimIssuer: "", + }, + "foo", + true, + false, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + "foo", + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + "foo", + false, + true, + }, + { + "ShouldFailNoMatch", + MapClaims{ + consts.ClaimIssuer: "bar", + }, + "foo", + true, + false, + }, + { + "ShouldFailWrongType", + MapClaims{ + consts.ClaimIssuer: 5, + }, + "5", + true, + false, + }, + { + "ShouldFailNil", + MapClaims{ + consts.ClaimIssuer: nil, + }, + "foo", + true, + false, + }, + { + "ShouldPassNil", + MapClaims{ + consts.ClaimIssuer: nil, + }, + "foo", + false, + true, + }, } - want := true - got := mapClaims.VerifyAudience("foo", true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyIssuer(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_string_aud(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: "foo", +func TestMapClaims_VerifySubject(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp string + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimSubject: "foo", + }, + "foo", + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + "foo", + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + "foo", + false, + true, + }, + { + "ShouldFailNoMatch", + MapClaims{ + consts.ClaimSubject: "bar", + }, + "foo", + true, + false, + }, + { + "ShouldFailWrongType", + MapClaims{ + consts.ClaimSubject: 5, + }, + "5", + true, + false, + }, + { + "ShouldFailNil", + MapClaims{ + consts.ClaimSubject: nil, + }, + "foo", + true, + false, + }, + { + "ShouldPassNil", + MapClaims{ + consts.ClaimSubject: nil, + }, + "foo", + false, + true, + }, } - want := true - got := mapClaims.VerifyAudience("foo", true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifySubject(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_list_aud_no_match(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: []string{"bar"}, +func TestMapClaims_VerifyExpiresAt(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp int64 + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimExpirationTime: int64(123), + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt", + MapClaims{ + consts.ClaimExpirationTime: 123, + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt32", + MapClaims{ + consts.ClaimExpirationTime: int32(123), + }, + int64(123), + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + int64(123), + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + int64(123), + false, + true, + }, + { + "ShouldFailNoMatch", + MapClaims{ + consts.ClaimExpirationTime: 4, + }, + int64(123), + true, + false, + }, + { + "ShouldFailWrongType", + MapClaims{ + consts.ClaimExpirationTime: true, + }, + int64(123), + true, + false, + }, + { + "ShouldFailNil", + MapClaims{ + consts.ClaimExpirationTime: nil, + }, + int64(123), + true, + false, + }, + { + "ShouldPassNil", + MapClaims{ + consts.ClaimExpirationTime: nil, + }, + int64(123), + false, + true, + }, } - want := false - got := mapClaims.VerifyAudience("foo", true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyExpiresAt(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_string_aud_fail(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: "bar", +func TestMapClaims_VerifyIssuedAt(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp int64 + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimIssuedAt: int64(123), + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt", + MapClaims{ + consts.ClaimIssuedAt: 123, + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt32", + MapClaims{ + consts.ClaimIssuedAt: int32(123), + }, + int64(123), + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + int64(123), + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + int64(123), + false, + true, + }, + { + "ShouldFailFuture", + MapClaims{ + consts.ClaimIssuedAt: 9000, + }, + int64(123), + true, + false, + }, + { + "ShouldPassPast", + MapClaims{ + consts.ClaimIssuedAt: 4, + }, + int64(123), + true, + true, + }, + { + "ShouldFailWrongType", + MapClaims{ + consts.ClaimIssuedAt: true, + }, + int64(123), + true, + false, + }, + { + "ShouldFailNil", + MapClaims{ + consts.ClaimIssuedAt: nil, + }, + int64(123), + true, + false, + }, + { + "ShouldPassNil", + MapClaims{ + consts.ClaimIssuedAt: nil, + }, + int64(123), + false, + true, + }, } - want := false - got := mapClaims.VerifyAudience("foo", true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyIssuedAt(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_string_aud_no_claim(t *testing.T) { - mapClaims := MapClaims{} - want := false - got := mapClaims.VerifyAudience("foo", true) +func TestMapClaims_VerifyNotBefore(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp int64 + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimNotBefore: int64(123), + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt", + MapClaims{ + consts.ClaimNotBefore: 123, + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt32", + MapClaims{ + consts.ClaimNotBefore: int32(123), + }, + int64(123), + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + int64(123), + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + int64(123), + false, + true, + }, + { + "ShouldFailFuture", + MapClaims{ + consts.ClaimNotBefore: 9000, + }, + int64(123), + true, + false, + }, + { + "ShouldPassPast", + MapClaims{ + consts.ClaimNotBefore: 4, + }, + int64(123), + true, + true, + }, + { + "ShouldFailWrongType", + MapClaims{ + consts.ClaimNotBefore: true, + }, + int64(123), + true, + false, + }, + { + "ShouldFailNil", + MapClaims{ + consts.ClaimNotBefore: nil, + }, + int64(123), + true, + false, + }, + { + "ShouldPassNil", + MapClaims{ + consts.ClaimNotBefore: nil, + }, + int64(123), + false, + true, + }, + } - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyNotBefore(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_string_aud_no_claim_not_required(t *testing.T) { - mapClaims := MapClaims{} - want := true - got := mapClaims.VerifyAudience("foo", false) +func TestMapClaims_Valid(t *testing.T) { + testCases := []struct { + name string + have MapClaims + opts []ClaimValidationOption + errs []uint32 + err string + }{ + { + "ShouldPass", + MapClaims{}, + nil, + nil, + "", + }, + { + "ShouldFailEXPNotPresent", + MapClaims{}, + []ClaimValidationOption{ValidateRequireExpiresAt()}, + []uint32{ValidationErrorExpired}, + "Token is expired", + }, + { + "ShouldFailIATNotPresent", + MapClaims{}, + []ClaimValidationOption{ValidateRequireIssuedAt()}, + []uint32{ValidationErrorIssuedAt}, + "Token used before issued", + }, + { + "ShouldFailNBFNotPresent", + MapClaims{}, + []ClaimValidationOption{ValidateRequireNotBefore()}, + []uint32{ValidationErrorNotValidYet}, + "Token is not valid yet", + }, + { + "ShouldFailExpPast", + MapClaims{ + consts.ClaimExpirationTime: 1, + }, + nil, + []uint32{ValidationErrorExpired}, + "Token is expired", + }, + { + "ShouldFailIssuedFuture", + MapClaims{ + consts.ClaimIssuedAt: 999999999999999, + }, + nil, + []uint32{ValidationErrorIssuedAt}, + "Token used before issued", + }, + { + "ShouldFailMultiple", + MapClaims{ + consts.ClaimExpirationTime: 1, + consts.ClaimIssuedAt: 999999999999999, + }, + nil, + []uint32{ValidationErrorIssuedAt, ValidationErrorExpired}, + "Token used before issued", + }, + { + "ShouldPassIssuer", + MapClaims{ + consts.ClaimIssuer: "abc", + }, + []ClaimValidationOption{ValidateIssuer("abc")}, + nil, + "", + }, + { + "ShouldFailIssuer", + MapClaims{ + consts.ClaimIssuer: "abc", + }, + []ClaimValidationOption{ValidateIssuer("abc2"), ValidateTimeFunc(time.Now)}, + []uint32{ValidationErrorIssuer}, + "Token has invalid issuer", + }, + { + "ShouldFailIssuerAbsent", + MapClaims{}, + []ClaimValidationOption{ValidateIssuer("abc2")}, + []uint32{ValidationErrorIssuer}, + "Token has invalid issuer", + }, + { + "ShouldPassSubject", + MapClaims{ + consts.ClaimSubject: "abc", + }, + []ClaimValidationOption{ValidateSubject("abc")}, + nil, + "", + }, + { + "ShouldFailSubject", + MapClaims{ + consts.ClaimSubject: "abc", + }, + []ClaimValidationOption{ValidateSubject("abc2")}, + []uint32{ValidationErrorSubject}, + "Token has invalid subject", + }, + { + "ShouldFailSubjectAbsent", + MapClaims{}, + []ClaimValidationOption{ValidateSubject("abc2")}, + []uint32{ValidationErrorSubject}, + "Token has invalid subject", + }, + { + "ShouldPassAudienceAll", + MapClaims{ + consts.ClaimAudience: []any{"abc", "123"}, + }, + []ClaimValidationOption{ValidateAudienceAll("abc", "123")}, + nil, + "", + }, + { + "ShouldFailAudienceAll", + MapClaims{ + consts.ClaimAudience: []any{"abc", "123"}, + }, + []ClaimValidationOption{ValidateAudienceAll("abc", "123", "456")}, + []uint32{ValidationErrorAudience}, + "Token has invalid audience", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := tc.have.Valid(tc.opts...) + + if len(tc.err) == 0 && tc.err == "" { + assert.NoError(t, actual) + } else { + if tc.err != "" { + assert.EqualError(t, actual, tc.err) + } + + var e *ValidationError + + errors.As(actual, &e) + + var errs uint32 + + for _, err := range tc.errs { + errs |= err + + assert.True(t, e.Has(err)) + } - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + assert.Equal(t, errs, e.Errors) + } + }) } } diff --git a/token/jwt/claims_test.go b/token/jwt/claims_test.go index 667de76c..e06f39b9 100644 --- a/token/jwt/claims_test.go +++ b/token/jwt/claims_test.go @@ -26,3 +26,44 @@ func TestToTime(t *testing.T) { assert.Equal(t, now, ToTime(now.Unix())) assert.Equal(t, now, ToTime(float64(now.Unix()))) } + +func TestFilter(t *testing.T) { + testCases := []struct { + name string + have map[string]any + filter []string + expected map[string]any + }{ + { + "ShouldFilterNone", + map[string]any{"abc": 123}, + []string{}, + map[string]any{"abc": 123}, + }, + { + "ShouldFilterNoneNil", + map[string]any{"abc": 123}, + []string{}, + map[string]any{"abc": 123}, + }, + { + "ShouldFilterAll", + map[string]any{"abc": 123, "example": 123}, + []string{"abc", "example"}, + map[string]any{}, + }, + { + "ShouldFilterSome", + map[string]any{"abc": 123, "example": 123}, + []string{"abc"}, + map[string]any{"example": 123}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + have := Filter(tc.have, tc.filter...) + assert.Equal(t, tc.expected, have) + }) + } +} From bff1f4abe001f37573392b4796bff9df7f855944 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 27 Sep 2024 23:40:43 +1000 Subject: [PATCH 23/33] client auth tests --- authorize_request_handler.go | 2 +- ...orize_request_handler_oidc_request_test.go | 214 +++++++++++++++--- token/jwt/jwt_strategy.go | 6 +- token/jwt/jwt_strategy_test.go | 2 +- token/jwt/token.go | 19 +- token/jwt/util.go | 8 +- 6 files changed, 207 insertions(+), 44 deletions(-) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 36552c29..0e01938f 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -547,7 +547,7 @@ func fmtRequestObjectDecodeError(token *jwt.Token, client JARClient, issuer stri case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): return outer.WithDebugf("%s client with id '%s' provided a request object that was not able to be verified. %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) case errJWTValidation.Has(jwt.ValidationErrorSignatureInvalid): - return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid signature. %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid signature.", hintRequestObjectPrefix(openid), client.GetID()) case errJWTValidation.Has(jwt.ValidationErrorExpired): exp, ok := token.Claims.GetExpiresAt() if ok { diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index c8059888..5f45d5dc 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -5,14 +5,18 @@ package oauth2 import ( "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/rsa" + "encoding/base64" "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" "regexp" + "strings" "testing" "github.com/go-jose/go-jose/v4" @@ -24,34 +28,113 @@ import ( ) func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { - key, err := rsa.GenerateKey(rand.Reader, 1024) //nolint:gosec + keyRSA, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) - jwks := &jose.JSONWebKeySet{ + keyECDSA, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + jwkNone := jose.JSONWebKey{ + Key: jwt.UnsafeAllowNoneSignatureType, + } + + rawClientSecretHS256 := "aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd" + + clientSecretHS256 := NewPlainTextClientSecret(rawClientSecretHS256) + + jwkEncHS := jose.JSONWebKey{ + Key: []byte(rawClientSecretHS256), + Algorithm: string(jose.A128GCMKW), + Use: consts.JSONWebTokenUseEncryption, + } + + fmt.Println(jwkEncHS) + + jwkSigHS := jose.JSONWebKey{ + Key: []byte(rawClientSecretHS256), + KeyID: "hs256-sig", + Algorithm: string(jose.HS256), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPublicSigRSA := jose.JSONWebKey{ + Key: keyRSA.Public(), + KeyID: "rs256-sig", + Algorithm: string(jose.RS256), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPrivateSigRSA := jose.JSONWebKey{ + Key: keyRSA, + KeyID: "rs256-sig", + Algorithm: string(jose.RS256), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPublicSigRSA384 := jose.JSONWebKey{ + Key: keyRSA.Public(), + KeyID: "rs384-sig", + Algorithm: string(jose.RS384), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPrivateSigRSA384 := jose.JSONWebKey{ + Key: keyRSA, + KeyID: "rs384-sig", + Algorithm: string(jose.RS384), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPublicSigECDSA := jose.JSONWebKey{ + Key: keyECDSA.Public(), + KeyID: "es256-sig", + Algorithm: string(jose.ES256), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPrivateSigECDSA := jose.JSONWebKey{ + Key: keyECDSA, + KeyID: "es256-sig", + Algorithm: string(jose.ES256), + Use: consts.JSONWebTokenUseSignature, + } + + jwksPrivate := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ - { - KeyID: "kid-foo", - Use: "sig", - Algorithm: string(jose.RS256), - Key: &key.PublicKey, - }, + jwkPrivateSigRSA, + jwkPrivateSigECDSA, }, } - assertionRequestObjectValid := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidRequestInRequest := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequest: "abc", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidRequestURIInRequest := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequestURI: "https://auth.example.com", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidClientIDValue := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: 100, consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidResponseTypeValue := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: 100, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidAudience := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.not-example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidIssuer := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "not-foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectValidWithoutKID := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz"}, key, "") - assertionRequestObjectValidNone := mustGenerateNoneAssertion(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state", consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}}) + fmt.Println(jwksPrivate) + + jwksPublic := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + jwkPublicSigRSA, + jwkPublicSigRSA384, + jwkPublicSigECDSA, + }, + } + + assertionRequestObjectValid := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) + assertionRequestObjectInvalidSignature := mangleSig(assertionRequestObjectValid) + assertionRequestObjectInvalidKID := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA384) + assertionRequestObjectInvalidTyp := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, &jwt.Headers{Extra: map[string]any{consts.JSONWebTokenHeaderType: "abc"}}, &jwkPrivateSigRSA) + assertionRequestObjectEmptyHS256 := mustGenerateRequestObjectJWS(t, jwt.MapClaims{}, nil, &jwkSigHS) + assertionRequestObjectInvalidRequestInRequest := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequest: "abc", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) + assertionRequestObjectInvalidRequestURIInRequest := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequestURI: "https://auth.example.com", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) + assertionRequestObjectInvalidClientIDValue := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: 100, consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) + assertionRequestObjectInvalidResponseTypeValue := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: 100, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) + assertionRequestObjectInvalidAudience := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.not-example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) + assertionRequestObjectInvalidIssuer := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "not-foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) + assertionRequestObjectValidWithoutKID := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz"}, nil, &jose.JSONWebKey{Key: keyRSA, Algorithm: string(jose.RS256), Use: consts.JSONWebTokenUseSignature}) + assertionRequestObjectValidNone := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state", consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}}, nil, &jwkNone) + assertionRequestObjectValidHS256 := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state", consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}}, nil, &jwkSigHS) mux := http.NewServeMux() var handlerJWKS http.HandlerFunc = func(rw http.ResponseWriter, r *http.Request) { - require.NoError(t, json.NewEncoder(rw).Encode(jwks)) + require.NoError(t, json.NewEncoder(rw).Encode(jwksPublic)) } handleString := func(in string) http.HandlerFunc { @@ -110,7 +193,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldPassRequest", have: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValid}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, }, { @@ -184,23 +267,55 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailInvalidTokenMalformed", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {"bad-token"}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' provided a request object that was malformed. The request object does not appear to be a JWE or JWS compact serialized JWT.", }, { name: "ShouldFailUnknownKID", - have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {mustGenerateAssertion(t, jwt.MapClaims{}, key, "does-not-exists")}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test"}}, + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {mustGenerateAssertion(t, jwt.MapClaims{}, keyRSA, "does-not-exists")}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test"}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' provided a request object that was not able to be verified. Error occurred retrieving the JSON Web Key. The JSON Web Token uses signing key with kid 'does-not-exists' which was not found.", }, + { + name: "ShouldFailBadKID", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidKID}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", RequestObjectSigningKeyID: "rs256-sig", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'kid' header value 'rs256-sig' due to the client registration 'request_object_signing_key_id' value but the request object was signed with the 'kid' header value 'rs384-sig'.", + }, + { + name: "ShouldFailBadType", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidTyp}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", RequestObjectSigningKeyID: "rs256-sig", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'typ' header value 'JWT' but the request object was signed with the 'typ' header value 'abc'.", + }, + { + name: "ShouldFailBadContentType", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidTyp}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", RequestObjectSigningKeyID: "rs256-sig", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'typ' header value 'JWT' but the request object was signed with the 'typ' header value 'abc'.", + }, + { + name: "ShouldFailBadSignature", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidSignature}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", RequestObjectSigningKeyID: "rs256-sig", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' provided a request object that has an invalid signature.", + }, { name: "ShouldFailBadAlgRS256", - have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {mustGenerateHSAssertion(t, jwt.MapClaims{})}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test", ClientSecret: NewPlainTextClientSecret("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")}}, + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectEmptyHS256}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'alg' header value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value 'HS256'.", @@ -208,7 +323,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailMismatchedClientID", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"not-foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectValid}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'client_id' claim with a value of 'foo' which is required to match the value 'not-foo' in the parameter with the same name from the OAuth 2.0 request syntax.", @@ -216,7 +331,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailRequestClientIDAssert", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"not-foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectInvalidClientIDValue}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectInvalidClientIDValue}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'client_id' claim with a value of '100' which is required to match the value 'not-foo' in the parameter with the same name from the OAuth 2.0 request syntax but instead of a string it had the int64 type.", @@ -224,7 +339,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailRequestWithRequest", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectInvalidRequestInRequest}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectInvalidRequestInRequest}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object which contained the 'request' or 'request_uri' claims but this is not permitted.", @@ -232,7 +347,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailRequestWithRequestURI", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectInvalidRequestURIInRequest}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectInvalidRequestURIInRequest}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object which contained the 'request' or 'request_uri' claims but this is not permitted.", @@ -240,7 +355,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailMismatchedResponseType", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectValid}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'response_type' claim with a value of 'token' which is required to match the value 'code' in the parameter with the same name from the OAuth 2.0 request syntax.", @@ -248,7 +363,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailMismatchedResponseTypeAsserted", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectInvalidResponseTypeValue}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectInvalidResponseTypeValue}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'response_type' claim with a value of '100' which is required to match the value 'code' in the parameter with the same name from the OAuth 2.0 request syntax but instead of a string it had the int64 type.", @@ -256,13 +371,13 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldPassWithoutKID", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidWithoutKID}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidWithoutKID}, "foo": {"bar"}, "baz": {"baz"}}, }, { name: "ShouldFailRequestURINotWhiteListed", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "standard.jwk").String()}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidWithoutKID}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestURI, errRegex: regexp.MustCompile(`^The request_uri in the authorization request returns an error or contains invalid data\. OpenID Connect 1\.0 request failed to fetch request parameters from the provided 'request_uri'\. The OAuth 2\.0 client with id 'foo' provided the 'request_uri' parameter with value 'http://127.0.0.1:\d+/request-object/valid/standard\.jwk' which is not whitelisted.$`), @@ -317,6 +432,12 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, RequestURIs: []string{root.JoinPath("request-object", "valid", "none.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, }, + { + name: "ShouldPassRequestAlgHS256", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidHS256}}, + client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: string(jose.HS256), DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidHS256}, "foo": {"bar"}, "baz": {"baz"}}, + }, { name: "ShouldPassRequestAlgNoneAllowAny", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidNone}}, @@ -424,9 +545,30 @@ func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims) string { return tokenString } -func mustGenerateNoneAssertion(t *testing.T, claims jwt.MapClaims) string { - token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) - tokenString, err := token.CompactSignedString(jwt.UnsafeAllowNoneSignatureType) +func mangleSig(tokenString string) string { + parts := strings.Split(tokenString, ".") + raw, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + panic(err) + } + + raw = append(raw, []byte("abc")...) + + parts[2] = base64.RawURLEncoding.EncodeToString(raw) + + return strings.Join(parts, ".") +} + +func mustGenerateRequestObjectJWS(t *testing.T, claims jwt.MapClaims, headers jwt.Mapper, key *jose.JSONWebKey) string { + token, _, err := jwt.EncodeCompactSigned(context.TODO(), claims, headers, key) require.NoError(t, err) - return tokenString + + return token +} + +func mustGenerateRequestObjectJWE(t *testing.T, claims jwt.MapClaims, headers, headersJWE jwt.Mapper, key *jose.JSONWebKey, keyEnc *jose.JSONWebKey, enc jose.ContentEncryption) string { + token, _, err := jwt.EncodeNestedCompactEncrypted(context.TODO(), claims, headers, headersJWE, key, keyEnc, enc) + require.NoError(t, err) + + return token } diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index f068e474..f2d858e3 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -72,13 +72,13 @@ func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (toke } if o.client == nil { - return encodeCompactSigned(ctx, o.claims, o.headers, keySig) + return EncodeCompactSigned(ctx, o.claims, o.headers, keySig) } kid, alg, enc := o.client.GetEncryptionKeyID(), o.client.GetEncryptionAlg(), o.client.GetEncryptionEnc() if len(kid) == 0 && len(alg) == 0 { - return encodeCompactSigned(ctx, o.claims, o.headers, keySig) + return EncodeCompactSigned(ctx, o.claims, o.headers, keySig) } if len(enc) == 0 { @@ -95,7 +95,7 @@ func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (toke return "", "", errorsx.WithStack(fmt.Errorf("Failed to encrypt the JWT using the client configuration. %w", err)) } - return encodeNestedCompactEncrypted(ctx, o.claims, o.headers, o.headersJWE, keySig, keyEnc, jose.ContentEncryption(enc)) + return EncodeNestedCompactEncrypted(ctx, o.claims, o.headers, o.headersJWE, keySig, keyEnc, jose.ContentEncryption(enc)) } func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, opts ...StrategyOpt) (tokenString, signature string, jwe *jose.JSONWebEncryption, err error) { diff --git a/token/jwt/jwt_strategy_test.go b/token/jwt/jwt_strategy_test.go index 60f8ca61..02112b60 100644 --- a/token/jwt/jwt_strategy_test.go +++ b/token/jwt/jwt_strategy_test.go @@ -719,7 +719,7 @@ func TestIniit(t *testing.T) { "exp": time.Now().Add(time.Hour * 24 * 365 * 40).UTC().Unix(), } - out, _, err := encodeNestedCompactEncrypted(context.TODO(), claims, &Headers{}, &Headers{}, &testKeySigECDSA, &testKeyPublicEncECDSA, jose.A128GCM) + out, _, err := EncodeNestedCompactEncrypted(context.TODO(), claims, &Headers{}, &Headers{}, &testKeySigECDSA, &testKeyPublicEncECDSA, jose.A128GCM) fmt.Println(err) fmt.Println(out) diff --git a/token/jwt/token.go b/token/jwt/token.go index dcefdb20..356962c9 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -317,7 +317,7 @@ func (t *Token) CompactSigned(k any) (tokenString, signature string, err error) // // > Get the complete, signed token func (t *Token) CompactSignedString(k any) (tokenString string, err error) { - if _, ok := k.(unsafeNoneMagicConstant); ok { + if isUnsafeNoneMagicConstant(k) { return unsignedToken(t) } @@ -594,3 +594,20 @@ func validateTokenType(typValues []string, header map[string]any) bool { return false } + +func isUnsafeNoneMagicConstant(k any) bool { + switch key := k.(type) { + case unsafeNoneMagicConstant: + return true + case jose.JSONWebKey: + if _, ok := key.Key.(unsafeNoneMagicConstant); ok { + return true + } + case *jose.JSONWebKey: + if _, ok := key.Key.(unsafeNoneMagicConstant); ok { + return true + } + } + + return false +} diff --git a/token/jwt/util.go b/token/jwt/util.go index 42a1625d..17f7729b 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -309,15 +309,19 @@ func NewJWKFromClientSecret(ctx context.Context, client BaseClient, kid, alg, us }, nil } -func encodeCompactSigned(ctx context.Context, claims MapClaims, headers Mapper, key *jose.JSONWebKey) (tokenString string, signature string, err error) { +func EncodeCompactSigned(ctx context.Context, claims MapClaims, headers Mapper, key *jose.JSONWebKey) (tokenString string, signature string, err error) { token := New() + if headers == nil { + headers = &Headers{} + } + token.SetJWS(headers, claims, key.KeyID, jose.SignatureAlgorithm(key.Algorithm)) return token.CompactSigned(key) } -func encodeNestedCompactEncrypted(ctx context.Context, claims MapClaims, headers, headersJWE Mapper, keySig, keyEnc *jose.JSONWebKey, enc jose.ContentEncryption) (tokenString string, signature string, err error) { +func EncodeNestedCompactEncrypted(ctx context.Context, claims MapClaims, headers, headersJWE Mapper, keySig, keyEnc *jose.JSONWebKey, enc jose.ContentEncryption) (tokenString string, signature string, err error) { token := New() if headers == nil { From 95370b65cc5c5448f95ca8d94a5f28447f93b6a5 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sat, 28 Sep 2024 14:14:26 +1000 Subject: [PATCH 24/33] client auth tests --- authorize_request_handler.go | 2 + ...orize_request_handler_oidc_request_test.go | 165 +++++++++++++----- client_authentication_strategy.go | 2 + handler/oauth2/strategy_jwt_profile.go | 2 + token/jwt/jwt_strategy.go | 20 +-- token/jwt/token.go | 35 ++-- token/jwt/util.go | 146 ++++++++++------ token/jwt/validation_error.go | 41 ++--- 8 files changed, 272 insertions(+), 141 deletions(-) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 0e01938f..1d993e0d 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -532,6 +532,8 @@ func fmtRequestObjectDecodeError(token *jwt.Token, client JARClient, issuer stri return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'typ' header value '%s' but the request object was signed with the 'typ' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionTypeInvalid): return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'typ' header value '%s' but the request object was encrypted with the 'typ' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalidMismatch): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with a 'cty' header value and signed with a 'typ' value that match but the request object was encrypted with the 'cty' header value '%s' and signed with the 'typ' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), token.HeaderJWE[consts.JSONWebTokenHeaderContentType], token.HeaderJWE[consts.JSONWebTokenHeaderType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalid): return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'cty' header value '%s' but the request object was encrypted with the 'cty' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderContentType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index 5f45d5dc..c020a7b8 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -18,6 +18,7 @@ import ( "regexp" "strings" "testing" + "time" "github.com/go-jose/go-jose/v4" "github.com/stretchr/testify/assert" @@ -34,75 +35,85 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) keyECDSA, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) - jwkNone := jose.JSONWebKey{ + jwkNone := &jose.JSONWebKey{ Key: jwt.UnsafeAllowNoneSignatureType, } - rawClientSecretHS256 := "aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd" + rawClientSecret := "aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd" - clientSecretHS256 := NewPlainTextClientSecret(rawClientSecretHS256) + clientSecretHS256 := NewPlainTextClientSecret(rawClientSecret) - jwkEncHS := jose.JSONWebKey{ - Key: []byte(rawClientSecretHS256), - Algorithm: string(jose.A128GCMKW), - Use: consts.JSONWebTokenUseEncryption, - } - - fmt.Println(jwkEncHS) + jwkEncAES256, err := jwt.NewClientSecretJWK(context.TODO(), []byte(rawClientSecret), "", string(jose.A256GCMKW), "", consts.JSONWebTokenUseEncryption) + require.NoError(t, err) - jwkSigHS := jose.JSONWebKey{ - Key: []byte(rawClientSecretHS256), + jwkSigHS := &jose.JSONWebKey{ + Key: []byte(rawClientSecret), KeyID: "hs256-sig", Algorithm: string(jose.HS256), Use: consts.JSONWebTokenUseSignature, } - jwkPublicSigRSA := jose.JSONWebKey{ + jwkPublicSigRSA := &jose.JSONWebKey{ Key: keyRSA.Public(), KeyID: "rs256-sig", Algorithm: string(jose.RS256), Use: consts.JSONWebTokenUseSignature, } - jwkPrivateSigRSA := jose.JSONWebKey{ + jwkPrivateSigRSA := &jose.JSONWebKey{ Key: keyRSA, KeyID: "rs256-sig", Algorithm: string(jose.RS256), Use: consts.JSONWebTokenUseSignature, } - jwkPublicSigRSA384 := jose.JSONWebKey{ + jwkPublicSigRSA384 := &jose.JSONWebKey{ Key: keyRSA.Public(), KeyID: "rs384-sig", Algorithm: string(jose.RS384), Use: consts.JSONWebTokenUseSignature, } - jwkPrivateSigRSA384 := jose.JSONWebKey{ + jwkPrivateSigRSA384 := &jose.JSONWebKey{ Key: keyRSA, KeyID: "rs384-sig", Algorithm: string(jose.RS384), Use: consts.JSONWebTokenUseSignature, } - jwkPublicSigECDSA := jose.JSONWebKey{ + jwkPublicSigECDSA := &jose.JSONWebKey{ Key: keyECDSA.Public(), KeyID: "es256-sig", Algorithm: string(jose.ES256), Use: consts.JSONWebTokenUseSignature, } - jwkPrivateSigECDSA := jose.JSONWebKey{ + jwkPrivateSigECDSA := &jose.JSONWebKey{ Key: keyECDSA, KeyID: "es256-sig", Algorithm: string(jose.ES256), Use: consts.JSONWebTokenUseSignature, } + jwkPublicEncECDSA := &jose.JSONWebKey{ + Key: keyECDSA.Public(), + KeyID: "es256-enc", + Algorithm: string(jose.ECDH_ES_A128KW), + Use: consts.JSONWebTokenUseEncryption, + } + + jwkPrivateEncECDSA := &jose.JSONWebKey{ + Key: keyECDSA, + KeyID: "es256-enc", + Algorithm: string(jose.ECDH_ES_A128KW), + Use: consts.JSONWebTokenUseEncryption, + } + jwksPrivate := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ - jwkPrivateSigRSA, - jwkPrivateSigECDSA, + *jwkPrivateSigRSA, + *jwkPrivateSigECDSA, + *jwkPrivateEncECDSA, }, } @@ -110,26 +121,34 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) jwksPublic := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ - jwkPublicSigRSA, - jwkPublicSigRSA384, - jwkPublicSigECDSA, + *jwkPublicSigRSA, + *jwkPublicSigRSA384, + *jwkPublicSigECDSA, + *jwkPublicEncECDSA, }, } - assertionRequestObjectValid := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) + assertionRequestObjectValid := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidExpired := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimExpirationTime: time.Now().Add(-time.Hour).UTC().Unix(), consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidFuture := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimIssuedAt: time.Now().Add(time.Hour).UTC().Unix(), consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidNotValidYet := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimNotBefore: time.Now().Add(time.Hour).UTC().Unix(), consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) assertionRequestObjectInvalidSignature := mangleSig(assertionRequestObjectValid) - assertionRequestObjectInvalidKID := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA384) - assertionRequestObjectInvalidTyp := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, &jwt.Headers{Extra: map[string]any{consts.JSONWebTokenHeaderType: "abc"}}, &jwkPrivateSigRSA) - assertionRequestObjectEmptyHS256 := mustGenerateRequestObjectJWS(t, jwt.MapClaims{}, nil, &jwkSigHS) - assertionRequestObjectInvalidRequestInRequest := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequest: "abc", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) - assertionRequestObjectInvalidRequestURIInRequest := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequestURI: "https://auth.example.com", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) - assertionRequestObjectInvalidClientIDValue := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: 100, consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) - assertionRequestObjectInvalidResponseTypeValue := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: 100, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) - assertionRequestObjectInvalidAudience := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.not-example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) - assertionRequestObjectInvalidIssuer := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "not-foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwkPrivateSigRSA) + assertionRequestObjectInvalidKID := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA384) + assertionRequestObjectInvalidTyp := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, &jwt.Headers{Extra: map[string]any{consts.JSONWebTokenHeaderType: "abc"}}, jwkPrivateSigRSA) + assertionRequestObjectInvalidJWEContentType := mustGenerateRequestObjectJWE(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwt.Headers{Extra: map[string]any{consts.JSONWebTokenHeaderContentType: "at+jwt"}}, jwkPrivateSigRSA, jwkEncAES256, jose.A256GCM) + assertionRequestObjectInvalidJWEType := mustGenerateRequestObjectJWE(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwt.Headers{Extra: map[string]any{consts.JSONWebTokenHeaderType: "at+jwt"}}, jwkPrivateSigRSA, jwkEncAES256, jose.A256GCM) + assertionRequestObjectValidJWE := mustGenerateRequestObjectJWE(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, nil, jwkPrivateSigRSA, jwkEncAES256, jose.A256GCM) + assertionRequestObjectValidAssymetricJWE := mustGenerateRequestObjectJWE(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, nil, jwkPrivateSigECDSA, jwkPublicEncECDSA, jose.A128GCM) + assertionRequestObjectEmptyHS256 := mustGenerateRequestObjectJWS(t, jwt.MapClaims{}, nil, jwkSigHS) + assertionRequestObjectInvalidRequestInRequest := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequest: "abc", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidRequestURIInRequest := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequestURI: "https://auth.example.com", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidClientIDValue := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: 100, consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidResponseTypeValue := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: 100, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidAudience := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.not-example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidIssuer := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "not-foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) assertionRequestObjectValidWithoutKID := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz"}, nil, &jose.JSONWebKey{Key: keyRSA, Algorithm: string(jose.RS256), Use: consts.JSONWebTokenUseSignature}) - assertionRequestObjectValidNone := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state", consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}}, nil, &jwkNone) - assertionRequestObjectValidHS256 := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state", consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}}, nil, &jwkSigHS) + assertionRequestObjectValidNone := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state", consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}}, nil, jwkNone) + assertionRequestObjectValidHS256 := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state", consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}}, nil, jwkSigHS) mux := http.NewServeMux() @@ -196,6 +215,18 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, }, + { + name: "ShouldPassRequestJWE", + have: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidJWE}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterRequest: {assertionRequestObjectValidJWE}, "foo": {"bar"}, "baz": {"baz"}}, + }, + { + name: "ShouldPassRequestJWESymmetric", + have: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidAssymetricJWE}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "ES256", RequestObjectSigningKeyID: "es256-sig", RequestObjectEncryptionKeyID: "es256-enc", DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidAssymetricJWE}, "baz": {"baz"}, "foo": {"bar"}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}}, + }, { name: "ShouldFailRequestNotOpenIDConnectClient", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterRequest: {"foo"}}, @@ -297,12 +328,68 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'typ' header value 'JWT' but the request object was signed with the 'typ' header value 'abc'.", }, { - name: "ShouldFailBadContentType", - have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidTyp}}, + name: "ShouldFailJWEBadContentType", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidJWEContentType}}, client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", RequestObjectSigningKeyID: "rs256-sig", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'typ' header value 'JWT' but the request object was signed with the 'typ' header value 'abc'.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be encrypted with a 'cty' header value and signed with a 'typ' value that match but the request object was encrypted with the 'cty' header value 'at+jwt' and signed with the 'typ' header value 'JWT'.", + }, + { + name: "ShouldFailJWEBadType", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidJWEType}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", RequestObjectSigningKeyID: "rs256-sig", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be encrypted with the 'typ' header value 'JWT' but the request object was encrypted with the 'typ' header value 'at+jwt'.", + }, + { + name: "ShouldFailJWEBadKeyID", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidAssymetricJWE}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "ES256", RequestObjectSigningKeyID: "es256-sig", RequestObjectEncryptionKeyID: "abc", DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be encrypted with the 'kid' header value 'abc' due to the client registration 'request_object_encryption_key_id' value but the request object was encrypted with the 'kid' header value 'es256-enc'.", + }, + { + name: "ShouldFailJWEBadAlg", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidAssymetricJWE}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "ES256", RequestObjectSigningKeyID: "es256-sig", RequestObjectEncryptionAlg: "abc", DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be encrypted with the 'alg' header value 'abc' due to the client registration 'request_object_encryption_alg' value but the request object was encrypted with the 'alg' header value 'ECDH-ES+A128KW'.", + }, + { + name: "ShouldFailJWEBadEnc", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidAssymetricJWE}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "ES256", RequestObjectSigningKeyID: "es256-sig", RequestObjectEncryptionEnc: "abc", DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be encrypted with the 'enc' header value 'abc' due to the client registration 'request_object_encryption_enc' value but the request object was encrypted with the 'enc' header value 'A128GCM'.", + }, + { + name: "ShouldFailExpired", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectInvalidExpired}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errRegex: regexp.MustCompile(`^The request parameter contains an invalid Request Object\. OpenID Connect 1\.0 request object could not be decoded or validated\. OpenID Connect 1\.0 client with id 'foo' provided a request object that was expired\. The request object expired at \d+\.`), + }, + { + name: "ShouldFailFuture", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectInvalidFuture}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errRegex: regexp.MustCompile(`^The request parameter contains an invalid Request Object. OpenID Connect 1\.0 request object could not be decoded or validated\. OpenID Connect 1\.0 client with id 'foo' provided a request object that was issued in the future\. The request object was issued at \d+\.$`), + }, + { + name: "ShouldFailNotBefore", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectInvalidNotValidYet}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errRegex: regexp.MustCompile(`^The request parameter contains an invalid Request Object\. OpenID Connect 1\.0 request object could not be decoded or validated\. OpenID Connect 1\.0 client with id 'foo' provided a request object that was issued in the future\. The request object is not valid before \d+\.`), }, { name: "ShouldFailBadSignature", @@ -501,7 +588,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) strategy := &jwt.DefaultStrategy{ Config: config, - Issuer: jwt.MustGenDefaultIssuer(), + Issuer: jwt.NewDefaultIssuerUnverifiedFromJWKS(jwksPrivate), } provider := &Fosite{Config: &Config{JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), IDTokenIssuer: "https://auth.example.com", JWTStrategy: strategy}} diff --git a/client_authentication_strategy.go b/client_authentication_strategy.go index 36ad6462..c92c6583 100644 --- a/client_authentication_strategy.go +++ b/client_authentication_strategy.go @@ -399,6 +399,8 @@ func fmtClientAssertionDecodeError(token *jwt.Token, client AuthenticationMethod return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'typ' header value '%s' but the client assertion was signed with the 'typ' header value '%s'.", client.GetID(), consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionTypeInvalid): return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'typ' header value '%s' but the client assertion was encrypted with the 'typ' header value '%s'.", client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalidMismatch): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with a 'cty' header value and signed with a 'typ' value that match but the client assertions was encrypted with the 'cty' header value '%s' and signed with the 'typ' header value '%s'.", client.GetID(), token.HeaderJWE[consts.JSONWebTokenHeaderContentType], token.HeaderJWE[consts.JSONWebTokenHeaderType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalid): return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'cty' header value '%s' but the client assertion was encrypted with the 'cty' header value '%s'.", client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderContentType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): diff --git a/handler/oauth2/strategy_jwt_profile.go b/handler/oauth2/strategy_jwt_profile.go index 7dd51d5d..9811bc8d 100644 --- a/handler/oauth2/strategy_jwt_profile.go +++ b/handler/oauth2/strategy_jwt_profile.go @@ -209,6 +209,8 @@ func fmtValidateJWTError(token *jwt.Token, client jwt.Client, inner error) (err return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'typ' header value '%s' but it was signed with the 'typ' header value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionTypeInvalid): return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'typ' header value '%s' but it was encrypted with the 'typ' header value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalidMismatch): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with a 'cty' header value and signed with a 'typ' value that match but it was encrypted with the 'cty' header value '%s' and signed with the 'typ' header value '%s'.", clientText, token.HeaderJWE[consts.JSONWebTokenHeaderContentType], token.HeaderJWE[consts.JSONWebTokenHeaderType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalid): return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'cty' header value '%s' but it was encrypted with the 'cty' header value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderContentType]) case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index f2d858e3..e507d08e 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -88,7 +88,7 @@ func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (toke var keyEnc *jose.JSONWebKey if IsEncryptedJWTClientSecretAlg(alg) { - if keyEnc, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + if keyEnc, err = NewClientSecretJWKFromClient(ctx, o.client, kid, alg, enc, consts.JSONWebTokenUseEncryption); err != nil { return "", "", errorsx.WithStack(fmt.Errorf("Failed to encrypt the JWT using the client secret. %w", err)) } } else if keyEnc, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseEncryption, false); err != nil { @@ -128,10 +128,10 @@ func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, op } var ( - kid, alg, cty string + kid, alg, enc string ) - if kid, alg, _, cty, err = headerValidateJWE(jwe.Header); err != nil { + if kid, alg, enc, _, err = headerValidateJWE(jwe.Header); err != nil { return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } @@ -144,7 +144,7 @@ func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, op return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } - if key, err = NewJWKFromClientSecret(ctx, o.client, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + if key, err = NewClientSecretJWKFromClient(ctx, o.client, kid, alg, enc, consts.JSONWebTokenUseEncryption); err != nil { return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { @@ -159,16 +159,6 @@ func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, op tokenString = string(tokenRaw) - var t *jwt.JSONWebToken - - if t, err = jwt.ParseSigned(tokenString, o.sigAlgorithm); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } - - if err = headerValidateJWSNested(t.Headers, cty); err != nil { - return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) - } - if signature, err = getJWTSignature(tokenString); err != nil { return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) } @@ -296,7 +286,7 @@ func (j *DefaultStrategy) validate(ctx context.Context, t *jwt.JSONWebToken, des return errorsx.WithStack(&ValidationError{Errors: ValidationErrorHeaderKeyIDInvalid, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not support tokens with a kid but the token has kid '%s'", alg, kid)}) } - if key, err = NewJWKFromClientSecret(ctx, o.client, "", alg, consts.JSONWebTokenUseSignature); err != nil { + if key, err = NewClientSecretJWKFromClient(ctx, o.client, "", alg, "", consts.JSONWebTokenUseSignature); err != nil { return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } } else { diff --git a/token/jwt/token.go b/token/jwt/token.go index 356962c9..39c01e7f 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -366,8 +366,8 @@ func (t *Token) Valid(opts ...HeaderValidationOption) (err error) { if t.HeaderJWE != nil && (t.KeyAlgorithm != "" || t.ContentEncryption != "") { var ( - cty, typ, ttyp any - ok bool + typ any + ok bool ) if typ, ok = t.HeaderJWE[consts.JSONWebTokenHeaderType]; !ok || typ != consts.JSONWebTokenTypeJWT { @@ -375,14 +375,19 @@ func (t *Token) Valid(opts ...HeaderValidationOption) (err error) { vErr.Errors |= ValidationErrorHeaderEncryptionTypeInvalid } - if ttyp, ok = t.Header[consts.JSONWebTokenHeaderType]; !ok { - vErr.Inner = errors.New("token was signed with invalid typ") - vErr.Errors |= ValidationErrorHeaderTypeInvalid + ttyp := t.Header[consts.JSONWebTokenHeaderType] + cty := t.HeaderJWE[consts.JSONWebTokenHeaderContentType] + + if cty != ttyp { + vErr.Inner = errors.New("token was encrypted with a cty value that doesn't match the typ value") + vErr.Errors |= ValidationErrorHeaderContentTypeInvalidMismatch } - if cty, ok = t.HeaderJWE[consts.JSONWebTokenHeaderContentType]; !ok || cty != ttyp { - vErr.Inner = errors.New("token was encrypted with invalid cty or signed with an invalid typ") - vErr.Errors |= ValidationErrorHeaderContentTypeInvalid + if len(vopts.types) != 0 { + if !validateTokenTypeValue(vopts.types, cty) { + vErr.Inner = errors.New("token was encrypted with an invalid cty") + vErr.Errors |= ValidationErrorHeaderContentTypeInvalid + } } } @@ -571,9 +576,8 @@ func pointer(v any) any { return v } -func validateTokenType(typValues []string, header map[string]any) bool { +func validateTokenType(values []string, header map[string]any) bool { var ( - typ string raw any ok bool ) @@ -582,11 +586,20 @@ func validateTokenType(typValues []string, header map[string]any) bool { return false } + return validateTokenTypeValue(values, raw) +} + +func validateTokenTypeValue(values []string, raw any) bool { + var ( + typ string + ok bool + ) + if typ, ok = raw.(string); !ok { return false } - for _, t := range typValues { + for _, t := range values { if t == typ { return true } diff --git a/token/jwt/util.go b/token/jwt/util.go index 17f7729b..06483fdc 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -3,12 +3,16 @@ package jwt import ( "context" "crypto" + "crypto/aes" + "crypto/sha256" + "crypto/sha512" "fmt" + "hash" "regexp" "strings" "github.com/go-jose/go-jose/v4" - jjwt "github.com/go-jose/go-jose/v4/jwt" + "github.com/go-jose/go-jose/v4/jwt" "github.com/pkg/errors" "authelia.com/provider/oauth2/internal/consts" @@ -89,31 +93,6 @@ func headerValidateJWS(headers []jose.Header) (kid, alg string, err error) { return headers[0].KeyID, headers[0].Algorithm, nil } -func headerValidateJWSNested(headers []jose.Header, cty string) (err error) { - switch len(headers) { - case 1: - break - case 0: - return fmt.Errorf("jws header is missing") - default: - return fmt.Errorf("jws header is malformed") - } - - typ, ok := headers[0].ExtraHeaders[consts.JSONWebTokenHeaderType] - if !ok { - return fmt.Errorf("jws header 'typ' value is missing") - } - - switch typ { - case "": - return fmt.Errorf("jws header 'typ' value is empty") - case cty: - return nil - default: - return fmt.Errorf("jws header 'typ' value '%s' is invalid: jwe header 'cty' value '%s' should match the jws header 'typ' value", typ, cty) - } -} - func headerValidateJWE(header jose.Header) (kid, alg, enc, cty string, err error) { if header.KeyID == "" && !IsEncryptedJWTClientSecretAlg(header.Algorithm) { return "", "", "", "", fmt.Errorf("jwe header 'kid' value is missing or empty") @@ -137,7 +116,6 @@ func headerValidateJWE(header jose.Header) (kid, alg, enc, cty string, err error } else if p2c < 200000 { return "", "", "", "", fmt.Errorf("jwe header 'p2c' has an invalid value '%d': less than 200,000", int(p2c)) } - default: return "", "", "", "", fmt.Errorf("jwe header 'p2c' value has invalid type %T", p2c) } @@ -159,25 +137,8 @@ func headerValidateJWE(header jose.Header) (kid, alg, enc, cty string, err error } } - if value, ok = header.ExtraHeaders[consts.JSONWebTokenHeaderContentType]; !ok { - return "", "", "", "", fmt.Errorf("jwe header 'cty' value is missing") - } else { - switch ctyv := value.(type) { - case string: - switch ctyv { - case consts.JSONWebTokenTypeJWT, consts.JSONWebTokenTypeAccessToken, consts.JSONWebTokenTypeAccessTokenAlternative, consts.JSONWebTokenTypeTokenIntrospection: - cty = ctyv - break - default: - return "", "", "", "", fmt.Errorf("jwe header 'cty' value '%s' is invalid", cty) - } - default: - return "", "", "", "", fmt.Errorf("jwe header 'cty' value has invalid type %T", cty) - } - } - - if header.JSONWebKey != nil { - return "", "", "", "", fmt.Errorf("jwe header 'jwk' value is present but not supported") + if value, ok = header.ExtraHeaders[consts.JSONWebTokenHeaderContentType]; ok { + cty, _ = value.(string) } return header.KeyID, header.Algorithm, enc, cty, nil @@ -282,8 +243,8 @@ func SearchJWKS(jwks *jose.JSONWebKeySet, kid, alg, use string, strict bool) (ke } } -// NewJWKFromClientSecret returns a JWK from a client secret. -func NewJWKFromClientSecret(ctx context.Context, client BaseClient, kid, alg, use string) (jwk *jose.JSONWebKey, err error) { +// NewClientSecretJWKFromClient returns a client secret based JWK from a client. +func NewClientSecretJWKFromClient(ctx context.Context, client BaseClient, kid, alg, enc, use string) (jwk *jose.JSONWebKey, err error) { var ( secret []byte ok bool @@ -297,16 +258,89 @@ func NewJWKFromClientSecret(ctx context.Context, client BaseClient, kid, alg, us return nil, &JWKLookupError{Description: "The client is not configured with a client secret"} } + return NewClientSecretJWK(ctx, secret, kid, alg, enc, use) +} + +// NewClientSecretJWK returns a client secret based JWK from a client secret value. +// +// The symmetric encryption key is derived from the client_secret value by using the left-most bits of a truncated +// SHA-2 hash of the octets of the UTF-8 representation of the client_secret. For keys of 256 or fewer bits, SHA-256 +// is used; for keys of 257-384 bits, SHA-384 is used; for keys of 385-512 bits, SHA-512 is used. The hash value MUST +// be truncated retaining the left-most bits to the appropriate bit length for the AES key wrapping or direct +// encryption algorithm used, for instance, truncating the SHA-256 hash to 128 bits for A128KW. If a symmetric key with +// greater than 512 bits is needed, a different method of deriving the key from the client_secret would have to be +// defined by an extension. Symmetric encryption MUST NOT be used by public (non-confidential) Clients because of +// their inability to keep secrets. +func NewClientSecretJWK(ctx context.Context, secret []byte, kid, alg, enc, use string) (jwk *jose.JSONWebKey, err error) { if len(secret) == 0 { return nil, &JWKLookupError{Description: "The client is not configured with a client secret that can be used for symmetric algorithms"} } - return &jose.JSONWebKey{ - Key: secret, - KeyID: kid, - Algorithm: alg, - Use: use, - }, nil + switch use { + case consts.JSONWebTokenUseSignature: + return &jose.JSONWebKey{ + Key: secret, + KeyID: kid, + Algorithm: alg, + Use: use, + }, nil + case consts.JSONWebTokenUseEncryption: + var ( + hasher hash.Hash + bits int + ) + + keyAlg := jose.KeyAlgorithm(alg) + + switch keyAlg { + case jose.A128KW, jose.A128GCMKW, jose.A192KW, jose.A192GCMKW, jose.A256KW, jose.A256GCMKW, jose.PBES2_HS256_A128KW: + hasher = sha256.New() + case jose.PBES2_HS384_A192KW: + hasher = sha512.New384() + case jose.PBES2_HS512_A256KW, jose.DIRECT: + hasher = sha512.New() + default: + return nil, &JWKLookupError{Description: fmt.Sprintf("Unsupported algorithm '%s'", alg)} + } + + switch keyAlg { + case jose.A128KW, jose.A128GCMKW, jose.PBES2_HS256_A128KW: + bits = aes.BlockSize + case jose.A192KW, jose.A192GCMKW, jose.PBES2_HS384_A192KW: + bits = aes.BlockSize * 1.5 + case jose.A256KW, jose.A256GCMKW, jose.PBES2_HS512_A256KW: + bits = aes.BlockSize * 2 + case jose.DIRECT: + switch jose.ContentEncryption(enc) { + case jose.A128CBC_HS256, "": + bits = aes.BlockSize * 2 + case jose.A192CBC_HS384: + bits = aes.BlockSize * 3 + case jose.A256CBC_HS512: + bits = aes.BlockSize * 4 + default: + return nil, &JWKLookupError{Description: fmt.Sprintf("Unsupported content encryption for the direct key algorthm '%s'", enc)} + } + } + + if _, err = hasher.Write(secret); err != nil { + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to derive key from hashing the client secret. %s", err.Error())} + } + + return &jose.JSONWebKey{ + Key: hasher.Sum(nil)[:bits], + KeyID: kid, + Algorithm: alg, + Use: use, + }, nil + default: + return &jose.JSONWebKey{ + Key: secret, + KeyID: kid, + Algorithm: alg, + Use: use, + }, nil + } } func EncodeCompactSigned(ctx context.Context, claims MapClaims, headers Mapper, key *jose.JSONWebKey) (tokenString string, signature string, err error) { @@ -380,8 +414,8 @@ func getPublicJWK(jwk *jose.JSONWebKey) jose.JSONWebKey { return jwk.Public() } -func UnsafeParseSignedAny(tokenString string, dest any) (token *jjwt.JSONWebToken, err error) { - if token, err = jjwt.ParseSigned(tokenString, SignatureAlgorithmsNone); err != nil { +func UnsafeParseSignedAny(tokenString string, dest any) (token *jwt.JSONWebToken, err error) { + if token, err = jwt.ParseSigned(tokenString, SignatureAlgorithmsNone); err != nil { return nil, err } diff --git a/token/jwt/validation_error.go b/token/jwt/validation_error.go index 712a22df..ce583f76 100644 --- a/token/jwt/validation_error.go +++ b/token/jwt/validation_error.go @@ -6,26 +6,27 @@ package jwt // Validation provides a backwards compatible error definition // from `jwt-go` to `go-jose`. const ( - ValidationErrorMalformed uint32 = 1 << iota // Token is malformed - ValidationErrorMalformedNotCompactSerialized // Token is malformed specifically it does not have the compact serialized format. - ValidationErrorUnverifiable // Token could not be verified because of signing problems - ValidationErrorSignatureInvalid // Signature validation failed. - ValidationErrorHeaderKeyIDInvalid // Header KID invalid error. - ValidationErrorHeaderAlgorithmInvalid // Header ALG invalid error. - ValidationErrorHeaderTypeInvalid // Header TYP invalid error. - ValidationErrorHeaderEncryptionTypeInvalid // Header TYP invalid error (JWE). - ValidationErrorHeaderContentTypeInvalid // Header TYP invalid error (JWE). - ValidationErrorHeaderEncryptionKeyIDInvalid // Header KID invalid error (JWE). - ValidationErrorHeaderKeyAlgorithmInvalid // Header ALG invalid error (JWE). - ValidationErrorHeaderContentEncryptionInvalid // Header ENC invalid error (JWE). - ValidationErrorId // Claim JTI validation failed. - ValidationErrorAudience // Claim AUD validation failed. - ValidationErrorExpired // Claim EXP validation failed. - ValidationErrorIssuedAt // Claim IAT validation failed. - ValidationErrorNotValidYet // Claim NBF validation failed. - ValidationErrorIssuer // Claim ISS validation failed. - ValidationErrorSubject // Claim SUB validation failed. - ValidationErrorClaimsInvalid // Generic claims validation error. + ValidationErrorMalformed uint32 = 1 << iota // Token is malformed + ValidationErrorMalformedNotCompactSerialized // Token is malformed specifically it does not have the compact serialized format. + ValidationErrorUnverifiable // Token could not be verified because of signing problems + ValidationErrorSignatureInvalid // Signature validation failed. + ValidationErrorHeaderKeyIDInvalid // Header KID invalid error. + ValidationErrorHeaderAlgorithmInvalid // Header ALG invalid error. + ValidationErrorHeaderTypeInvalid // Header TYP invalid error. + ValidationErrorHeaderEncryptionTypeInvalid // Header TYP invalid error (JWE). + ValidationErrorHeaderContentTypeInvalid // Header TYP invalid error (JWE). + ValidationErrorHeaderContentTypeInvalidMismatch // Header TYP invalid error (JWE). + ValidationErrorHeaderEncryptionKeyIDInvalid // Header KID invalid error (JWE). + ValidationErrorHeaderKeyAlgorithmInvalid // Header ALG invalid error (JWE). + ValidationErrorHeaderContentEncryptionInvalid // Header ENC invalid error (JWE). + ValidationErrorId // Claim JTI validation failed. + ValidationErrorAudience // Claim AUD validation failed. + ValidationErrorExpired // Claim EXP validation failed. + ValidationErrorIssuedAt // Claim IAT validation failed. + ValidationErrorNotValidYet // Claim NBF validation failed. + ValidationErrorIssuer // Claim ISS validation failed. + ValidationErrorSubject // Claim SUB validation failed. + ValidationErrorClaimsInvalid // Generic claims validation error. ) // The ValidationError is an error implementation from Parse if token is not valid. From d31ada44b3124cb357a3fb132c7e0fd3bea6f080 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sat, 28 Sep 2024 14:48:45 +1000 Subject: [PATCH 25/33] client auth tests --- client.go | 4 ++-- token/jwt/claims_map_test.go | 9 ++++++--- token/jwt/client.go | 4 ++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index 1d6bc4e0..3a753867 100644 --- a/client.go +++ b/client.go @@ -394,7 +394,7 @@ type RequestedAudienceImplicitClient interface { type IntrospectionJWTResponseClient interface { // GetIntrospectionSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements for // OAuth 2.0 JWT introspection response specifications. If unspecified the other available parameters will be - // // utilized to select an appropriate key. + // utilized to select an appropriate key. GetIntrospectionSignedResponseKeyID() (kid string) // GetIntrospectionSignedResponseAlg is equivalent to the 'introspection_signed_response_alg' client metadata @@ -405,7 +405,7 @@ type IntrospectionJWTResponseClient interface { // GetIntrospectionEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements for // OAuth 2.0 JWT introspection response specifications. If unspecified the other available parameters will be - // // utilized to select an appropriate key. + // utilized to select an appropriate key. GetIntrospectionEncryptedResponseKeyID() (kid string) // GetIntrospectionEncryptedResponseAlg is equivalent to the 'introspection_encrypted_response_alg' client metadata diff --git a/token/jwt/claims_map_test.go b/token/jwt/claims_map_test.go index 13f1c0dc..24af2721 100644 --- a/token/jwt/claims_map_test.go +++ b/token/jwt/claims_map_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "authelia.com/provider/oauth2/internal/consts" ) @@ -143,7 +144,7 @@ func TestMapClaims_VerifyAudienceAll(t *testing.T) { true, }, { - "ShouldFailMultipleAny", + "ShouldFailMultipleAll", MapClaims{ consts.ClaimAudience: []string{"foo"}, }, @@ -273,7 +274,7 @@ func TestMapClaims_VerifyAudienceAny(t *testing.T) { true, }, { - "ShouldFailMultipleAny", + "ShouldPassMultipleAny", MapClaims{ consts.ClaimAudience: []string{"foo"}, }, @@ -864,7 +865,7 @@ func TestMapClaims_Valid(t *testing.T) { { "ShouldFailEXPNotPresent", MapClaims{}, - []ClaimValidationOption{ValidateRequireExpiresAt()}, + []ClaimValidationOption{ValidateRequireExpiresAt(), ValidateTimeFunc(func() time.Time { return time.Unix(0, 0) })}, []uint32{ValidationErrorExpired}, "Token is expired", }, @@ -995,6 +996,8 @@ func TestMapClaims_Valid(t *testing.T) { errors.As(actual, &e) + require.NotNil(t, e) + var errs uint32 for _, err := range tc.errs { diff --git a/token/jwt/client.go b/token/jwt/client.go index bca5214d..c9f51610 100644 --- a/token/jwt/client.go +++ b/token/jwt/client.go @@ -420,7 +420,7 @@ func (r *decoratedJWTProfileAccessTokenClient) IsClientSigned() (is bool) { type IntrospectionClient interface { // GetIntrospectionSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements for // OAuth 2.0 JWT introspection response specifications. If unspecified the other available parameters will be - // // utilized to select an appropriate key. + // utilized to select an appropriate key. GetIntrospectionSignedResponseKeyID() (kid string) // GetIntrospectionSignedResponseAlg is equivalent to the 'introspection_signed_response_alg' client metadata @@ -431,7 +431,7 @@ type IntrospectionClient interface { // GetIntrospectionEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements for // OAuth 2.0 JWT introspection response specifications. If unspecified the other available parameters will be - // // utilized to select an appropriate key. + // utilized to select an appropriate key. GetIntrospectionEncryptedResponseKeyID() (kid string) // GetIntrospectionEncryptedResponseAlg is equivalent to the 'introspection_encrypted_response_alg' client metadata From d669b0e713ffc603846ea40650929054754f7b6d Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sat, 28 Sep 2024 15:06:19 +1000 Subject: [PATCH 26/33] client auth tests --- handler/oauth2/strategy_jwt_profile.go | 2 +- introspection_response_writer.go | 8 ++++---- token/jarm/generate.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/handler/oauth2/strategy_jwt_profile.go b/handler/oauth2/strategy_jwt_profile.go index 9811bc8d..45ed1e59 100644 --- a/handler/oauth2/strategy_jwt_profile.go +++ b/handler/oauth2/strategy_jwt_profile.go @@ -176,7 +176,7 @@ func (s *JWTProfileCoreStrategy) GenerateJWT(ctx context.Context, tokenType oaut func validateJWT(ctx context.Context, strategy jwt.Strategy, client jwt.Client, tokenString string) (token *jwt.Token, err error) { if token, err = strategy.Decode(ctx, tokenString, jwt.WithClient(client)); err != nil { - return token, fmtValidateJWTError(token, client, err) + return nil, fmtValidateJWTError(token, client, err) } if err = token.Claims.Valid(); err != nil { diff --git a/introspection_response_writer.go b/introspection_response_writer.go index abb59121..ac65e3e0 100644 --- a/introspection_response_writer.go +++ b/introspection_response_writer.go @@ -283,15 +283,15 @@ func (f *Fosite) writeIntrospectionResponse(ctx context.Context, rw http.Respons claims[consts.ClaimAudience] = aud } - signer := f.Config.GetIntrospectionJWTResponseStrategy(ctx) + strategy := f.Config.GetIntrospectionJWTResponseStrategy(ctx) - if signer == nil { - f.WriteIntrospectionError(ctx, rw, errors.WithStack(ErrServerError.WithHint("Failed to generate the response.").WithDebug("The Introspection JWT could not be generated as the server is misconfigured. The Introspection Signer was not configured."))) + if strategy == nil { + f.WriteIntrospectionError(ctx, rw, errors.WithStack(ErrServerError.WithHint("Failed to generate the response.").WithDebug("The Introspection JWT could not be generated as the server is misconfigured. The Introspection jwt.Strategy was not configured."))) return } - if token, _, err = signer.Encode(ctx, jwt.WithClaims(claims), jwt.WithHeaders(header), jwt.WithIntrospectionClient(r.GetAccessRequester().GetClient())); err != nil { + if token, _, err = strategy.Encode(ctx, jwt.WithClaims(claims), jwt.WithHeaders(header), jwt.WithIntrospectionClient(r.GetAccessRequester().GetClient())); err != nil { f.WriteIntrospectionError(ctx, rw, errors.WithStack(ErrServerError.WithHint("Failed to generate the response.").WithDebugf("The Introspection JWT itself could not be generated with error %+v.", err))) return diff --git a/token/jarm/generate.go b/token/jarm/generate.go index 07fc1ff3..edd94ba5 100644 --- a/token/jarm/generate.go +++ b/token/jarm/generate.go @@ -76,7 +76,7 @@ func Generate(ctx context.Context, config Configurator, client Client, session a var signer jwt.Strategy if signer = config.GetJWTSecuredAuthorizeResponseModeStrategy(ctx); signer == nil { - return "", "", errors.New("The JARM response modes require the JWTSecuredAuthorizeResponseModeSignerProvider to return a jwt.Signer but it didn't.") + return "", "", errors.New("The JARM response modes require the JWTSecuredAuthorizeResponseModeSignerProvider to return a jwt.Strategy but it didn't.") } return signer.Encode(ctx, jwt.WithClaims(claims.ToMapClaims()), jwt.WithHeaders(&jwt.Headers{Extra: headers}), jwt.WithJARMClient(client)) From b43bc14814928820589e5ff66e5c1c42c6a9d1ca Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sat, 28 Sep 2024 16:29:57 +1000 Subject: [PATCH 27/33] docs: add docs --- client_authentication.go | 4 ++++ token/jarm/types.go | 1 + token/jwt/client.go | 11 +++++++++++ token/jwt/util.go | 5 +++++ 4 files changed, 21 insertions(+) diff --git a/client_authentication.go b/client_authentication.go index 4bc22158..d40aebef 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -161,15 +161,19 @@ type EndpointClientAuthHandler interface { // GetAuthMethod returns the appropriate auth method for this client. GetAuthMethod(client AuthenticationMethodClient) string + // GetAuthSigningKeyID returns the appropriate auth signature key id for this client. GetAuthSigningKeyID(client AuthenticationMethodClient) string // GetAuthSigningAlg returns the appropriate auth signature algorithm for this client. GetAuthSigningAlg(client AuthenticationMethodClient) string + // GetAuthEncryptionKeyID returns the appropriate auth encryption key id for this client. GetAuthEncryptionKeyID(client AuthenticationMethodClient) string + // GetAuthEncryptionAlg returns the appropriate auth encryption key algorithm for this client. GetAuthEncryptionAlg(client AuthenticationMethodClient) string + // GetAuthEncryptionEnc returns the appropriate auth encryption content encryption for this client. GetAuthEncryptionEnc(client AuthenticationMethodClient) string // Name returns the appropriate name for this endpoint for logging purposes. diff --git a/token/jarm/types.go b/token/jarm/types.go index e82bea65..4bb0fdae 100644 --- a/token/jarm/types.go +++ b/token/jarm/types.go @@ -17,6 +17,7 @@ type Client interface { // GetID returns the client ID. GetID() (id string) + // IsPublic returns true if the client has the public client type. IsPublic() (public bool) // GetAuthorizationSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements of the diff --git a/token/jwt/client.go b/token/jwt/client.go index c9f51610..7066e15c 100644 --- a/token/jwt/client.go +++ b/token/jwt/client.go @@ -4,6 +4,7 @@ import ( "github.com/go-jose/go-jose/v4" ) +// NewJARClient converts a type into a Client provided it implements the JARClient. func NewJARClient(client any) Client { switch c := client.(type) { case JARClient: @@ -13,6 +14,7 @@ func NewJARClient(client any) Client { } } +// NewIDTokenClient converts a type into a Client provided it implements the IDTokenClient. func NewIDTokenClient(client any) Client { switch c := client.(type) { case IDTokenClient: @@ -22,6 +24,7 @@ func NewIDTokenClient(client any) Client { } } +// NewJARMClient converts a type into a Client provided it implements the JARMClient. func NewJARMClient(client any) Client { switch c := client.(type) { case JARMClient: @@ -31,6 +34,7 @@ func NewJARMClient(client any) Client { } } +// NewUserInfoClient converts a type into a Client provided it implements the UserInfoClient. func NewUserInfoClient(client any) Client { switch c := client.(type) { case UserInfoClient: @@ -40,6 +44,7 @@ func NewUserInfoClient(client any) Client { } } +// NewJWTProfileAccessTokenClient converts a type into a Client provided it implements the JWTProfileAccessTokenClient. func NewJWTProfileAccessTokenClient(client any) Client { switch c := client.(type) { case JWTProfileAccessTokenClient: @@ -49,6 +54,7 @@ func NewJWTProfileAccessTokenClient(client any) Client { } } +// NewIntrospectionClient converts a type into a Client provided it implements the IntrospectionClient. func NewIntrospectionClient(client any) Client { switch c := client.(type) { case IntrospectionClient: @@ -58,6 +64,8 @@ func NewIntrospectionClient(client any) Client { } } +// NewStatelessJWTProfileIntrospectionClient converts a type into a Client provided it implements either the +// IntrospectionClient or JWTProfileAccessTokenClient. func NewStatelessJWTProfileIntrospectionClient(client any) Client { switch c := client.(type) { case IntrospectionClient: @@ -69,6 +77,7 @@ func NewStatelessJWTProfileIntrospectionClient(client any) Client { } } +// Client represents a client which can be used to sign, verify, encrypt, and decrypt JWT's. type Client interface { GetSigningKeyID() (kid string) GetSigningAlg() (alg string) @@ -81,6 +90,7 @@ type Client interface { BaseClient } +// BaseClient represents the base implementation for any JWT compatible client. type BaseClient interface { // GetID returns the client ID. GetID() string @@ -105,6 +115,7 @@ type BaseClient interface { GetJSONWebKeysURI() (uri string) } +// JARClient represents the implementation for any JWT Authorization Request compatible client. type JARClient interface { // GetRequestObjectSigningKeyID returns the specific key identifier used to satisfy JWS requirements of the request // object specifications. If unspecified the other available parameters will be utilized to select an appropriate diff --git a/token/jwt/util.go b/token/jwt/util.go index 06483fdc..79f94475 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -343,6 +343,7 @@ func NewClientSecretJWK(ctx context.Context, secret []byte, kid, alg, enc, use s } } +// EncodeCompactSigned helps encoding a token using a signature backed compact encoding. func EncodeCompactSigned(ctx context.Context, claims MapClaims, headers Mapper, key *jose.JSONWebKey) (tokenString string, signature string, err error) { token := New() @@ -355,6 +356,8 @@ func EncodeCompactSigned(ctx context.Context, claims MapClaims, headers Mapper, return token.CompactSigned(key) } +// EncodeNestedCompactEncrypted helps encoding a token using a signature backed compact encoding, then nests that within +// an encrypted compact encoded JWT. func EncodeNestedCompactEncrypted(ctx context.Context, claims MapClaims, headers, headersJWE Mapper, keySig, keyEnc *jose.JSONWebKey, enc jose.ContentEncryption) (tokenString string, signature string, err error) { token := New() @@ -414,6 +417,8 @@ func getPublicJWK(jwk *jose.JSONWebKey) jose.JSONWebKey { return jwk.Public() } +// UnsafeParseSignedAny is a function that will attempt to parse any signed token without any verification process. +// It's unsafe for production and should only be used for tests. func UnsafeParseSignedAny(tokenString string, dest any) (token *jwt.JSONWebToken, err error) { if token, err = jwt.ParseSigned(tokenString, SignatureAlgorithmsNone); err != nil { return nil, err From 5df0ecf4fe9652f4b144c11d957d87494d5764f5 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sat, 28 Sep 2024 17:06:04 +1000 Subject: [PATCH 28/33] fix: lint --- handler/oauth2/strategy_jwt_profile.go | 4 ++-- token/jwt/claims_map_test.go | 4 ++-- token/jwt/client.go | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/handler/oauth2/strategy_jwt_profile.go b/handler/oauth2/strategy_jwt_profile.go index 45ed1e59..113d164b 100644 --- a/handler/oauth2/strategy_jwt_profile.go +++ b/handler/oauth2/strategy_jwt_profile.go @@ -258,9 +258,9 @@ func fmtValidateJWTError(token *jwt.Token, client jwt.Client, inner error) (err case errJWTValidation.Has(jwt.ValidationErrorAudience): aud, ok := token.Claims.GetAudience() if ok { - return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid audience. The token was expected to have an 'iss' claim with one of the following values: ''. The 'iss' claim has a value of '%s'.", clientText, aud) + return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid audience. The token was expected to have an 'iss' claim with one of the following values: ''. The 'aud' claim has a value of '%s'.", clientText, aud) } else { - return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid audience. The token does not have an 'iss' claim or it has an invalid type.", clientText) + return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid audience. The token does not have an 'aud' claim or it has an invalid type.", clientText) } case errJWTValidation.Has(jwt.ValidationErrorClaimsInvalid): return oauth2.ErrTokenClaim.WithDebugf("Token %shas invalid claims. Error occurred trying to validate the request objects claims: %s", clientText, strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) diff --git a/token/jwt/claims_map_test.go b/token/jwt/claims_map_test.go index 24af2721..d5fdb53d 100644 --- a/token/jwt/claims_map_test.go +++ b/token/jwt/claims_map_test.go @@ -276,9 +276,9 @@ func TestMapClaims_VerifyAudienceAny(t *testing.T) { { "ShouldPassMultipleAny", MapClaims{ - consts.ClaimAudience: []string{"foo"}, + consts.ClaimAudience: []string{"foo", "baz"}, }, - []string{"foo", "bar"}, + []string{"bar", "baz"}, true, true, }, diff --git a/token/jwt/client.go b/token/jwt/client.go index 7066e15c..0082f94e 100644 --- a/token/jwt/client.go +++ b/token/jwt/client.go @@ -397,6 +397,9 @@ type JWTProfileAccessTokenClient interface { // MUST NOT be specified without setting access_token_encrypted_response_alg. GetAccessTokenEncryptedResponseEnc() (alg string) + // GetEnableJWTProfileOAuthAccessTokens indicates this client should or should not issue JWT Profile Access Tokens. + GetEnableJWTProfileOAuthAccessTokens() (enforce bool) + BaseClient } From 14388173598beb2f1d43eac665c26494664df7f7 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sat, 28 Sep 2024 23:44:50 +1000 Subject: [PATCH 29/33] feat: claims interface --- arguments_test.go | 5 - authorize_helper_test.go | 1 - authorize_request_handler.go | 28 +- ...orize_request_handler_oidc_request_test.go | 2 - authorize_response_writer_test.go | 1 - client_authentication_strategy.go | 34 ++- handler/oauth2/introspector_jwt.go | 2 +- handler/oauth2/strategy_jwt_profile.go | 30 ++- handler/openid/flow_explicit_token_test.go | 4 +- handler/openid/flow_refresh_token_test.go | 4 +- handler/openid/helper_test.go | 1 - handler/openid/strategy_jwt.go | 18 +- handler/openid/validator.go | 11 +- handler/openid/validator_test.go | 3 +- .../rfc8628/token_endpoint_handler_test.go | 2 - handler/rfc8693/custom_jwt_type_handler.go | 4 +- handler/rfc8693/token_exchange_test.go | 2 +- helper_test.go | 1 - ...rize_code_grant_public_client_pkce_test.go | 8 - integration/client_credentials_grant_test.go | 2 - integration/introspect_token_test.go | 2 - ...e_owner_password_credentials_grant_test.go | 2 - introspection_response_writer.go | 4 +- pushed_authorize_response_writer_test.go | 1 - ...8_device_authorize_response_writer_test.go | 1 - ...628_user_authorize_request_handler_test.go | 1 - ...628_user_authorize_response_writer_test.go | 1 - session.go | 2 +- testing/mock/client.go | 4 +- token/hmac/hmacsha_test.go | 3 +- token/jarm/generate.go | 2 +- token/jwt/claims.go | 11 + token/jwt/claims_id_token.go | 30 +-- token/jwt/claims_id_token_test.go | 26 +- token/jwt/claims_jarm.go | 26 +- token/jwt/claims_jarm_test.go | 17 +- token/jwt/claims_jwt.go | 71 ++++-- token/jwt/claims_jwt_test.go | 28 +- token/jwt/claims_map.go | 241 ++++++++++++------ token/jwt/claims_map_test.go | 57 +---- token/jwt/consts.go | 65 ++++- token/jwt/date.go | 115 +++++++++ token/jwt/header.go | 6 +- token/jwt/issuer.go | 8 +- token/jwt/jwt_strategy.go | 32 ++- token/jwt/jwt_strategy_opts.go | 9 - token/jwt/jwt_strategy_test.go | 60 +++-- token/jwt/token.go | 73 +++--- token/jwt/token_test.go | 19 +- token/jwt/util.go | 36 ++- token/jwt/variables.go | 9 + 51 files changed, 693 insertions(+), 432 deletions(-) create mode 100644 token/jwt/date.go create mode 100644 token/jwt/variables.go diff --git a/arguments_test.go b/arguments_test.go index a60c8035..94b2e51f 100644 --- a/arguments_test.go +++ b/arguments_test.go @@ -54,7 +54,6 @@ func TestArgumentsExactOne(t *testing.T) { for k, c := range testCases { assert.Equal(t, c.expect, c.args.ExactOne(c.exact), "%d", k) - t.Logf("Passed test case %d", k) } } @@ -106,7 +105,6 @@ func TestArgumentsHas(t *testing.T) { }, } { assert.Equal(t, c.expect, c.args.Has(c.has...), "%d", k) - t.Logf("Passed test case %d", k) } } @@ -192,7 +190,6 @@ func TestArgumentsMatchesExact(t *testing.T) { }...) for k, c := range testCases { assert.Equal(t, c.expect, c.args.MatchesExact(c.is...), "%d", k) - t.Logf("Passed test case %d", k) } } @@ -224,7 +221,6 @@ func TestArgumentsMatches(t *testing.T) { }...) for k, c := range testCases { assert.Equal(t, c.expect, c.args.Matches(c.is...), "%d", k) - t.Logf("Passed test case %d", k) } } @@ -251,6 +247,5 @@ func TestArgumentsOneOf(t *testing.T) { }, } { assert.Equal(t, c.expect, c.args.HasOneOf(c.oneOf...), "%d", k) - t.Logf("Passed test case %d", k) } } diff --git a/authorize_helper_test.go b/authorize_helper_test.go index f7bf6c98..1e6a1fdc 100644 --- a/authorize_helper_test.go +++ b/authorize_helper_test.go @@ -221,7 +221,6 @@ func TestDoesClientWhiteListRedirect(t *testing.T) { require.NotNil(t, redir, "%d", k) assert.Equal(t, c.expected, redir.String(), "%d", k) } - t.Logf("Passed test case %d", k) } } diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 1d993e0d..89fa4f5a 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -151,7 +151,7 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co v any ) - for k, v = range claims { + for k, v = range claims.ToMapClaims() { switch k { case consts.FormParameterRequest, consts.FormParameterRequestURI: // The request and request_uri parameters MUST NOT be included in Request Objects. @@ -551,36 +551,36 @@ func fmtRequestObjectDecodeError(token *jwt.Token, client JARClient, issuer stri case errJWTValidation.Has(jwt.ValidationErrorSignatureInvalid): return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid signature.", hintRequestObjectPrefix(openid), client.GetID()) case errJWTValidation.Has(jwt.ValidationErrorExpired): - exp, ok := token.Claims.GetExpiresAt() - if ok { - return outer.WithDebugf("%s client with id '%s' provided a request object that was expired. The request object expired at %d.", hintRequestObjectPrefix(openid), client.GetID(), exp) + exp, err := token.Claims.GetExpirationTime() + if err == nil { + return outer.WithDebugf("%s client with id '%s' provided a request object that was expired. The request object expired at %d.", hintRequestObjectPrefix(openid), client.GetID(), exp.Int64()) } else { return outer.WithDebugf("%s client with id '%s' provided a request object that was expired. The request object does not have an 'exp' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) } case errJWTValidation.Has(jwt.ValidationErrorIssuedAt): - iat, ok := token.Claims.GetIssuedAt() - if ok { - return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object was issued at %d.", hintRequestObjectPrefix(openid), client.GetID(), iat) + iat, err := token.Claims.GetIssuedAt() + if err == nil { + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object was issued at %d.", hintRequestObjectPrefix(openid), client.GetID(), iat.Int64()) } else { return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object does not have an 'iat' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) } case errJWTValidation.Has(jwt.ValidationErrorNotValidYet): - nbf, ok := token.Claims.GetNotBefore() - if ok { - return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object is not valid before %d.", hintRequestObjectPrefix(openid), client.GetID(), nbf) + nbf, err := token.Claims.GetNotBefore() + if err == nil { + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object is not valid before %d.", hintRequestObjectPrefix(openid), client.GetID(), nbf.Int64()) } else { return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object does not have an 'nbf' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) } case errJWTValidation.Has(jwt.ValidationErrorIssuer): - iss, ok := token.Claims.GetIssuer() - if ok { + iss, err := token.Claims.GetIssuer() + if err == nil { return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid issuer. The request object was expected to have an 'iss' claim which matches the value '%s' but the 'iss' claim had the value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetID(), iss) } else { return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid issuer. The request object does not have an 'iss' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) } case errJWTValidation.Has(jwt.ValidationErrorAudience): - aud, ok := token.Claims.GetAudience() - if ok { + aud, err := token.Claims.GetAudience() + if err == nil { return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid audience. The request object was expected to have an 'aud' claim which matches the issuer value of '%s' but the 'aud' claim had the values '%s'.", hintRequestObjectPrefix(openid), client.GetID(), issuer, strings.Join(aud, "', '")) } else { return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid audience. The request object does not have an 'aud' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index c020a7b8..32920a4a 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -117,8 +117,6 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) }, } - fmt.Println(jwksPrivate) - jwksPublic := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ *jwkPublicSigRSA, diff --git a/authorize_response_writer_test.go b/authorize_response_writer_test.go index 7ba3d441..d6d33973 100644 --- a/authorize_response_writer_test.go +++ b/authorize_response_writer_test.go @@ -91,6 +91,5 @@ func TestNewAuthorizeResponse(t *testing.T) { } else { assert.NotNil(t, responder, "%d", k) } - t.Logf("Passed test case %d", k) } } diff --git a/client_authentication_strategy.go b/client_authentication_strategy.go index c92c6583..0b39c58f 100644 --- a/client_authentication_strategy.go +++ b/client_authentication_strategy.go @@ -156,10 +156,8 @@ func NewClientAssertion(ctx context.Context, strategy jwt.Strategy, store Client return &ClientAssertion{Assertion: assertion, Type: assertionType}, resolveJWTErrorToRFCError(err) } - var ok bool - - if id, ok = token.Claims.GetSubject(); !ok { - if id, ok = token.Claims.GetIssuer(); !ok { + if id, err = token.Claims.GetSubject(); err != nil || len(id) == 0 { + if id, err = token.Claims.GetIssuer(); err != nil || len(id) == 0 { return &ClientAssertion{Assertion: assertion, Type: assertionType}, nil } } @@ -273,7 +271,7 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(c claims := &jwt.JWTClaims{} - claims.FromMapClaims(token.Claims) + claims.FromMapClaims(token.Claims.ToMapClaims()) switch { case subtle.ConstantTimeCompare([]byte(claims.Issuer), clientID) == 0: @@ -418,36 +416,36 @@ func fmtClientAssertionDecodeError(token *jwt.Token, client AuthenticationMethod case errJWTValidation.Has(jwt.ValidationErrorSignatureInvalid): return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid signature. %s.", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) case errJWTValidation.Has(jwt.ValidationErrorExpired): - exp, ok := token.Claims.GetExpiresAt() - if ok { - return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was expired. The client assertion expired at %d.", client.GetID(), exp) + exp, err := token.Claims.GetExpirationTime() + if err == nil { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was expired. The client assertion expired at %d.", client.GetID(), exp.Int64()) } else { return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was expired. The client assertion does not have an 'exp' claim or it has an invalid type.", client.GetID()) } case errJWTValidation.Has(jwt.ValidationErrorIssuedAt): - iat, ok := token.Claims.GetIssuedAt() - if ok { - return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion was issued at %d.", client.GetID(), iat) + iat, err := token.Claims.GetIssuedAt() + if err == nil { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion was issued at %d.", client.GetID(), iat.Int64()) } else { return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion does not have an 'iat' claim or it has an invalid type.", client.GetID()) } case errJWTValidation.Has(jwt.ValidationErrorNotValidYet): - nbf, ok := token.Claims.GetNotBefore() - if ok { - return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion is not valid before %d.", client.GetID(), nbf) + nbf, err := token.Claims.GetNotBefore() + if err == nil { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion is not valid before %d.", client.GetID(), nbf.Int64()) } else { return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion does not have an 'nbf' claim or it has an invalid type.", client.GetID()) } case errJWTValidation.Has(jwt.ValidationErrorIssuer): - iss, ok := token.Claims.GetIssuer() - if ok { + iss, err := token.Claims.GetIssuer() + if err == nil { return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid issuer. The client assertion was expected to have an 'iss' claim which matches the value '%s' but the 'iss' claim had the value '%s'.", client.GetID(), client.GetID(), iss) } else { return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid issuer. The client assertion does not have an 'iss' claim or it has an invalid type.", client.GetID()) } case errJWTValidation.Has(jwt.ValidationErrorAudience): - aud, ok := token.Claims.GetAudience() - if ok { + aud, err := token.Claims.GetAudience() + if err == nil { return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid audience. The client assertion was expected to have an 'aud' claim which matches one of the values '%s' but the 'aud' claim had the values '%s'.", client.GetID(), strings.Join(audience, "', '"), strings.Join(aud, "', '")) } else { return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid audience. The client assertion does not have an 'aud' claim or it has an invalid type.", client.GetID()) diff --git a/handler/oauth2/introspector_jwt.go b/handler/oauth2/introspector_jwt.go index 2b55be38..78be9f64 100644 --- a/handler/oauth2/introspector_jwt.go +++ b/handler/oauth2/introspector_jwt.go @@ -44,7 +44,7 @@ func (v *StatelessJWTValidator) IntrospectToken(ctx context.Context, tokenString // AccessTokenJWTToRequest tries to reconstruct oauth2.Request from a JWT. func AccessTokenJWTToRequest(token *jwt.Token) oauth2.Requester { - mapClaims := token.Claims + mapClaims := token.Claims.ToMapClaims() claims := jwt.JWTClaims{} claims.FromMapClaims(mapClaims) diff --git a/handler/oauth2/strategy_jwt_profile.go b/handler/oauth2/strategy_jwt_profile.go index 113d164b..c553d5a5 100644 --- a/handler/oauth2/strategy_jwt_profile.go +++ b/handler/oauth2/strategy_jwt_profile.go @@ -171,7 +171,9 @@ func (s *JWTProfileCoreStrategy) GenerateJWT(ctx context.Context, tokenType oaut s.Config.GetJWTScopeField(ctx), ) - return s.Strategy.Encode(ctx, jwt.WithClaims(claims.ToMapClaims()), jwt.WithHeaders(header), jwt.WithJWTProfileAccessTokenClient(client)) + mapClaims := claims.ToMapClaims() + + return s.Strategy.Encode(ctx, mapClaims, jwt.WithHeaders(header), jwt.WithJWTProfileAccessTokenClient(client)) } func validateJWT(ctx context.Context, strategy jwt.Strategy, client jwt.Client, tokenString string) (token *jwt.Token, err error) { @@ -191,6 +193,7 @@ func fmtValidateJWTError(token *jwt.Token, client jwt.Client, inner error) (err clientText string sigKID, sigAlg string encKID, encAlg, enc string + date *jwt.NumericDate ) if client != nil { @@ -228,36 +231,35 @@ func fmtValidateJWTError(token *jwt.Token, client jwt.Client, inner error) (err case errJWTValidation.Has(jwt.ValidationErrorSignatureInvalid): return oauth2.ErrTokenSignatureMismatch.WithDebugf("Token %shas an invalid signature.", clientText) case errJWTValidation.Has(jwt.ValidationErrorExpired): - exp, ok := token.Claims.GetExpiresAt() - if ok { - return oauth2.ErrTokenExpired.WithDebugf("Token %sexpired at %d.", clientText, exp) + if date, err = token.Claims.GetExpirationTime(); err == nil { + return oauth2.ErrTokenExpired.WithDebugf("Token %sexpired at %d.", clientText, date.Int64()) } else { return oauth2.ErrTokenExpired.WithDebugf("Token %sdoes not have an 'exp' claim or it has an invalid type.", clientText) } case errJWTValidation.Has(jwt.ValidationErrorIssuedAt): - iat, ok := token.Claims.GetIssuedAt() - if ok { - return oauth2.ErrTokenClaim.WithDebugf("Token %sis issued in the future. The token was issued at %d.", clientText, iat) + if date, err = token.Claims.GetIssuedAt(); err == nil { + return oauth2.ErrTokenClaim.WithDebugf("Token %sis issued in the future. The token was issued at %d.", clientText, date.Int64()) } else { return oauth2.ErrTokenClaim.WithDebugf("Token %sis issued in the future. The token does not have an 'iat' claim or it has an invalid type.", clientText) } case errJWTValidation.Has(jwt.ValidationErrorNotValidYet): - nbf, ok := token.Claims.GetNotBefore() - if ok { - return oauth2.ErrTokenClaim.WithDebugf("Token %sis not valid yet. The token is not valid before %d.", clientText, nbf) + if date, err = token.Claims.GetNotBefore(); err == nil { + return oauth2.ErrTokenClaim.WithDebugf("Token %sis not valid yet. The token is not valid before %d.", clientText, date.Int64()) } else { return oauth2.ErrTokenClaim.WithDebugf("Token %sis not valid yet. The token does not have an 'nbf' claim or it has an invalid type.", clientText) } case errJWTValidation.Has(jwt.ValidationErrorIssuer): - iss, ok := token.Claims.GetIssuer() - if ok { + var iss string + + if iss, err = token.Claims.GetIssuer(); err == nil { return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid issuer. The token was expected to have an 'iss' claim with one of the following values: ''. The 'iss' claim has a value of '%s'.", clientText, iss) } else { return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid issuer. The token does not have an 'iss' claim or it has an invalid type.", clientText) } case errJWTValidation.Has(jwt.ValidationErrorAudience): - aud, ok := token.Claims.GetAudience() - if ok { + var aud jwt.ClaimStrings + + if aud, err = token.Claims.GetAudience(); err == nil { return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid audience. The token was expected to have an 'iss' claim with one of the following values: ''. The 'aud' claim has a value of '%s'.", clientText, aud) } else { return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid audience. The token does not have an 'aud' claim or it has an invalid type.", clientText) diff --git a/handler/openid/flow_explicit_token_test.go b/handler/openid/flow_explicit_token_test.go index 8c655bca..3993fb42 100644 --- a/handler/openid/flow_explicit_token_test.go +++ b/handler/openid/flow_explicit_token_test.go @@ -113,7 +113,7 @@ func TestExplicit_PopulateTokenEndpointResponse(t *testing.T) { return key.PublicKey, nil }) require.NoError(t, err) - claims := decodedIdToken.Claims + claims := decodedIdToken.Claims.ToMapClaims() assert.NotEmpty(t, claims["at_hash"]) idTokenExp := internal.ExtractJwtExpClaim(t, idToken) internal.RequireEqualTime(t, time.Now().Add(*internal.TestLifespans.AuthorizationCodeGrantIDTokenLifespan).UTC(), *idTokenExp, time.Minute) @@ -144,7 +144,7 @@ func TestExplicit_PopulateTokenEndpointResponse(t *testing.T) { return key.PublicKey, nil }) require.NoError(t, err) - claims := decodedIdToken.Claims + claims := decodedIdToken.Claims.ToMapClaims() assert.NotEmpty(t, claims["at_hash"]) idTokenExp := internal.ExtractJwtExpClaim(t, idToken) internal.RequireEqualTime(t, time.Now().Add(time.Hour), *idTokenExp, time.Minute) diff --git a/handler/openid/flow_refresh_token_test.go b/handler/openid/flow_refresh_token_test.go index 727a5aad..77e0bc0b 100644 --- a/handler/openid/flow_refresh_token_test.go +++ b/handler/openid/flow_refresh_token_test.go @@ -147,7 +147,7 @@ func TestOpenIDConnectRefreshHandler_PopulateTokenEndpointResponse(t *testing.T) return key.PublicKey, nil }) require.NoError(t, err) - claims := decodedIdToken.Claims + claims := decodedIdToken.Claims.ToMapClaims() assert.NotEmpty(t, claims[consts.ClaimAccessTokenHash]) idTokenExp := internal.ExtractJwtExpClaim(t, idToken) require.NotEmpty(t, idTokenExp) @@ -182,7 +182,7 @@ func TestOpenIDConnectRefreshHandler_PopulateTokenEndpointResponse(t *testing.T) return key.PublicKey, nil }) require.NoError(t, err) - claims := decodedIdToken.Claims + claims := decodedIdToken.Claims.ToMapClaims() assert.NotEmpty(t, claims[consts.ClaimAccessTokenHash]) idTokenExp := internal.ExtractJwtExpClaim(t, idToken) require.NotEmpty(t, idTokenExp) diff --git a/handler/openid/helper_test.go b/handler/openid/helper_test.go index b8041513..75b79e30 100644 --- a/handler/openid/helper_test.go +++ b/handler/openid/helper_test.go @@ -74,7 +74,6 @@ func TestGenerateIDToken(t *testing.T) { if err == nil { assert.NotEmpty(t, token, "(%d) %s", k, c.description) } - t.Logf("Passed test case %d", k) } } diff --git a/handler/openid/strategy_jwt.go b/handler/openid/strategy_jwt.go index 5ae1bfeb..646ff1b4 100644 --- a/handler/openid/strategy_jwt.go +++ b/handler/openid/strategy_jwt.go @@ -144,8 +144,9 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura jwtClient := jwt.NewIDTokenClient(requester.GetClient()) if requester.GetRequestForm().Get(consts.FormParameterGrantType) != consts.GrantTypeRefreshToken { - maxAge, err := strconv.ParseInt(requester.GetRequestForm().Get(consts.FormParameterMaximumAge), 10, 64) - if err != nil { + var maxAge int64 + + if maxAge, err = strconv.ParseInt(requester.GetRequestForm().Get(consts.FormParameterMaximumAge), 10, 64); err != nil { maxAge = 0 } @@ -192,7 +193,10 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura } if tokenHintString := requester.GetRequestForm().Get(consts.FormParameterIDTokenHint); tokenHintString != "" { - tokenHint, err := h.Strategy.Decode(ctx, tokenHintString, jwt.WithClient(jwtClient)) + var tokenHint *jwt.Token + + tokenHint, err = h.Strategy.Decode(ctx, tokenHintString, jwt.WithClient(jwtClient)) + var ve *jwt.ValidationError if errors.As(err, &ve) && ve.Has(jwt.ValidationErrorExpired) { // Expired ID Tokens are allowed as values to id_token_hint @@ -200,9 +204,11 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura return "", errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebugf("Unable to decode id token from 'id_token_hint' parameter because %s.", err.Error())) } - if hintSub, _ := tokenHint.Claims[consts.ClaimSubject].(string); hintSub == "" { + var subHint string + + if subHint, err = tokenHint.Claims.GetSubject(); subHint == "" || err != nil { return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Provided id token from 'id_token_hint' does not have a subject.")) - } else if hintSub != claims.Subject { + } else if subHint != claims.Subject { return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Subject from authorization mismatches id token subject from 'id_token_hint'.")) } } @@ -236,7 +242,7 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura claims.Audience = stringslice.Unique(append(claims.Audience, requester.GetClient().GetID())) claims.IssuedAt = time.Now().UTC() - token, _, err = h.Strategy.Encode(ctx, jwt.WithClaims(claims.ToMapClaims()), jwt.WithHeaders(sess.IDTokenHeaders()), jwt.WithClient(jwtClient)) + token, _, err = h.Strategy.Encode(ctx, claims.ToMapClaims(), jwt.WithHeaders(sess.IDTokenHeaders()), jwt.WithClient(jwtClient)) return token, err } diff --git a/handler/openid/validator.go b/handler/openid/validator.go index e432f3a9..85c3789d 100644 --- a/handler/openid/validator.go +++ b/handler/openid/validator.go @@ -144,7 +144,10 @@ func (v *OpenIDConnectRequestValidator) ValidatePrompt(ctx context.Context, req return nil } - tokenHint, err := v.Strategy.Decode(ctx, idTokenHint, jwt.WithIDTokenClient(req.GetClient())) + var tokenHint *jwt.Token + + tokenHint, err = v.Strategy.Decode(ctx, idTokenHint, jwt.WithIDTokenClient(req.GetClient())) + var ve *jwt.ValidationError if errors.As(err, &ve) && ve.Has(jwt.ValidationErrorExpired) { // Expired tokens are ok @@ -152,9 +155,11 @@ func (v *OpenIDConnectRequestValidator) ValidatePrompt(ctx context.Context, req return errorsx.WithStack(oauth2.ErrInvalidRequest.WithHint("Failed to validate OpenID Connect request as decoding id token from id_token_hint parameter failed.").WithWrap(err).WithDebugError(err)) } - if hintSub, _ := tokenHint.Claims[consts.ClaimSubject].(string); hintSub == "" { + var subHint string + + if subHint, err = tokenHint.Claims.GetSubject(); subHint == "" || err != nil { return errorsx.WithStack(oauth2.ErrInvalidRequest.WithHint("Failed to validate OpenID Connect request because provided id token from id_token_hint does not have a subject.")) - } else if hintSub != claims.Subject { + } else if subHint != claims.Subject { return errorsx.WithStack(oauth2.ErrLoginRequired.WithHint("Failed to validate OpenID Connect request because the subject from provided id token from id_token_hint does not match the current session's subject.")) } diff --git a/handler/openid/validator_test.go b/handler/openid/validator_test.go index 5a9c805f..c493cf40 100644 --- a/handler/openid/validator_test.go +++ b/handler/openid/validator_test.go @@ -35,7 +35,7 @@ func TestValidatePrompt(t *testing.T) { v := NewOpenIDConnectRequestValidator(j, config) var genIDToken = func(c jwt.IDTokenClaims) string { - s, _, err := j.Encode(context.TODO(), jwt.WithClaims(c.ToMapClaims())) + s, _, err := j.Encode(context.TODO(), c.ToMapClaims()) require.NoError(t, err) return s } @@ -251,7 +251,6 @@ func TestValidatePrompt(t *testing.T) { }, } { t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { - t.Logf("%s", tc.idTokenHint) err := v.ValidatePrompt(context.TODO(), &oauth2.AuthorizeRequest{ Request: oauth2.Request{ Form: url.Values{"prompt": {tc.prompt}, "id_token_hint": {tc.idTokenHint}}, diff --git a/handler/rfc8628/token_endpoint_handler_test.go b/handler/rfc8628/token_endpoint_handler_test.go index 514a1c38..90e1703d 100644 --- a/handler/rfc8628/token_endpoint_handler_test.go +++ b/handler/rfc8628/token_endpoint_handler_test.go @@ -458,8 +458,6 @@ func TestDeviceAuthorizeCode_HandleTokenEndpointRequest(t *testing.T) { c.setup(t, c.areq, c.authreq) } - t.Logf("Processing %+v", c.areq.Client) - err := h.HandleTokenEndpointRequest(context.Background(), c.areq) if c.expectErr != nil { require.EqualError(t, err, c.expectErr.Error(), "%+v", err) diff --git a/handler/rfc8693/custom_jwt_type_handler.go b/handler/rfc8693/custom_jwt_type_handler.go index 297c5349..18b4f392 100644 --- a/handler/rfc8693/custom_jwt_type_handler.go +++ b/handler/rfc8693/custom_jwt_type_handler.go @@ -123,7 +123,7 @@ func (c *CustomJWTTypeHandler) validate(ctx context.Context, _ oauth2.AccessRequ if window == 0 { window = 1 * time.Hour } - claims := ftoken.Claims + claims := ftoken.Claims.ToMapClaims() if issued, exists := claims[consts.ClaimIssuedAt]; exists { if time.Unix(toInt64(issued), 0).Add(window).Before(time.Now()) { @@ -203,7 +203,7 @@ func (c *CustomJWTTypeHandler) issue(ctx context.Context, request oauth2.AccessR claims.IssuedAt = time.Now().UTC() - token, _, err := c.Strategy.Encode(ctx, jwt.WithClaims(claims.ToMapClaims()), jwt.WithHeaders(sess.IDTokenHeaders()), jwt.WithIDTokenClient(request.GetClient())) + token, _, err := c.Strategy.Encode(ctx, claims.ToMapClaims(), jwt.WithHeaders(sess.IDTokenHeaders()), jwt.WithIDTokenClient(request.GetClient())) if err != nil { return err } diff --git a/handler/rfc8693/token_exchange_test.go b/handler/rfc8693/token_exchange_test.go index b34f1aff..3dbf3d19 100644 --- a/handler/rfc8693/token_exchange_test.go +++ b/handler/rfc8693/token_exchange_test.go @@ -266,7 +266,7 @@ func createAccessToken(ctx context.Context, coreStrategy hoauth2.CoreStrategy, s } func createJWT(ctx context.Context, client any, strategy jwt.Strategy, claims jwt.MapClaims) string { - token, _, err := strategy.Encode(ctx, jwt.WithClaims(claims), jwt.WithIDTokenClient(client)) + token, _, err := strategy.Encode(ctx, claims, jwt.WithIDTokenClient(client)) if err != nil { panic(err.Error()) diff --git a/helper_test.go b/helper_test.go index 0c68a9fa..fe5839ae 100644 --- a/helper_test.go +++ b/helper_test.go @@ -25,7 +25,6 @@ func TestStringInSlice(t *testing.T) { {needle: "foo", haystack: []string{}, ok: false}, } { assert.Equal(t, c.ok, StringInSlice(c.needle, c.haystack), "%d", k) - t.Logf("Passed test case %d", k) } } diff --git a/integration/authorize_code_grant_public_client_pkce_test.go b/integration/authorize_code_grant_public_client_pkce_test.go index 0d72a973..8e1a3525 100644 --- a/integration/authorize_code_grant_public_client_pkce_test.go +++ b/integration/authorize_code_grant_public_client_pkce_test.go @@ -79,19 +79,11 @@ func runAuthorizeCodeGrantWithPublicClientAndPKCETest(t *testing.T, strategy any t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), func(t *testing.T) { c.setup() - t.Logf("Got url: %s", authCodeUrl) - resp, err := http.Get(authCodeUrl) //nolint:gosec require.NoError(t, err) require.Equal(t, resp.StatusCode, c.authStatusCode) if resp.StatusCode == http.StatusOK { - // This should fail because no verifier was given - // _, err := oauthClient.Exchange(xoauth2.NoContext, resp.Request.URL.Query().Get(consts.FormParameterAuthorizationCode)) - // require.Error(t, err) - // require.Empty(t, token.AccessToken) - t.Logf("Got redirect url: %s", resp.Request.URL) - resp, err := http.PostForm(ts.URL+"/token", url.Values{ consts.FormParameterAuthorizationCode: {resp.Request.URL.Query().Get(consts.FormParameterAuthorizationCode)}, consts.FormParameterGrantType: {consts.GrantTypeAuthorizationCode}, diff --git a/integration/client_credentials_grant_test.go b/integration/client_credentials_grant_test.go index 261eb8fc..d38f8449 100644 --- a/integration/client_credentials_grant_test.go +++ b/integration/client_credentials_grant_test.go @@ -145,8 +145,6 @@ func runClientCredentialsGrantTest(t *testing.T, strategy hoauth2.AccessTokenStr if c.check != nil { c.check(t, token) } - - t.Logf("Passed test case %d", k) }) } } diff --git a/integration/introspect_token_test.go b/integration/introspect_token_test.go index 6d1ffa46..3c6de835 100644 --- a/integration/introspect_token_test.go +++ b/integration/introspect_token_test.go @@ -52,7 +52,6 @@ func TestIntrospectToken(t *testing.T) { factory: compose.OAuth2StatelessJWTIntrospectionFactory, }, } { - t.Logf("testing %v", c.description) runIntrospectTokenTest(t, c.strategy, c.factory) } } @@ -124,7 +123,6 @@ func runIntrospectTokenTest(t *testing.T, strategy hoauth2.AccessTokenStrategy, _, bytes, errs := c.prepare(s).End() assert.Nil(t, json.Unmarshal([]byte(bytes), &res)) - t.Logf("Got answer: %s", bytes) assert.Len(t, errs, 0) assert.Equal(t, c.isActive, res.Active) diff --git a/integration/resource_owner_password_credentials_grant_test.go b/integration/resource_owner_password_credentials_grant_test.go index ec235b02..ac024223 100644 --- a/integration/resource_owner_password_credentials_grant_test.go +++ b/integration/resource_owner_password_credentials_grant_test.go @@ -94,7 +94,5 @@ func runResourceOwnerPasswordCredentialsGrantTest(t *testing.T, strategy hoauth2 c.check(t, token) } } - - t.Logf("Passed test case %d", k) } } diff --git a/introspection_response_writer.go b/introspection_response_writer.go index ac65e3e0..99c7f2cb 100644 --- a/introspection_response_writer.go +++ b/introspection_response_writer.go @@ -272,7 +272,7 @@ func (f *Fosite) writeIntrospectionResponse(ctx context.Context, rw http.Respons return } - claims := map[string]any{ + claims := jwt.MapClaims{ consts.ClaimJWTID: jti.String(), consts.ClaimIssuer: f.Config.GetIntrospectionIssuer(ctx), consts.ClaimIssuedAt: time.Now().UTC().Unix(), @@ -291,7 +291,7 @@ func (f *Fosite) writeIntrospectionResponse(ctx context.Context, rw http.Respons return } - if token, _, err = strategy.Encode(ctx, jwt.WithClaims(claims), jwt.WithHeaders(header), jwt.WithIntrospectionClient(r.GetAccessRequester().GetClient())); err != nil { + if token, _, err = strategy.Encode(ctx, claims, jwt.WithHeaders(header), jwt.WithIntrospectionClient(r.GetAccessRequester().GetClient())); err != nil { f.WriteIntrospectionError(ctx, rw, errors.WithStack(ErrServerError.WithHint("Failed to generate the response.").WithDebugf("The Introspection JWT itself could not be generated with error %+v.", err))) return diff --git a/pushed_authorize_response_writer_test.go b/pushed_authorize_response_writer_test.go index cff8428f..b92e781f 100644 --- a/pushed_authorize_response_writer_test.go +++ b/pushed_authorize_response_writer_test.go @@ -57,6 +57,5 @@ func TestNewPushedAuthorizeResponse(t *testing.T) { } else { assert.NotNil(t, responder, "%d", k) } - t.Logf("Passed test case %d", k) } } diff --git a/rfc8628_device_authorize_response_writer_test.go b/rfc8628_device_authorize_response_writer_test.go index 798e6fea..884ce975 100644 --- a/rfc8628_device_authorize_response_writer_test.go +++ b/rfc8628_device_authorize_response_writer_test.go @@ -68,6 +68,5 @@ func TestNewDeviceResponse(t *testing.T) { } else { assert.NotNil(t, responder, "%d", k) } - t.Logf("Passed test case %d", k) } } diff --git a/rfc8628_user_authorize_request_handler_test.go b/rfc8628_user_authorize_request_handler_test.go index 6fbb5242..4887178c 100644 --- a/rfc8628_user_authorize_request_handler_test.go +++ b/rfc8628_user_authorize_request_handler_test.go @@ -75,6 +75,5 @@ func TestFosite_NewRFC8628UserAuthorizeRequest(t *testing.T) { assert.NotNil(t, resp, "%d", k) assert.Equal(t, req.Form, resp.GetRequestForm()) } - t.Logf("Passed test case %d", k) } } diff --git a/rfc8628_user_authorize_response_writer_test.go b/rfc8628_user_authorize_response_writer_test.go index f17296bc..46861116 100644 --- a/rfc8628_user_authorize_response_writer_test.go +++ b/rfc8628_user_authorize_response_writer_test.go @@ -71,6 +71,5 @@ func TestFosite_NewRFC8628UserAuthorizeResponse(t *testing.T) { } else { assert.NotNil(t, responder, "%d", k) } - t.Logf("Passed test case %d", k) } } diff --git a/session.go b/session.go index ea691175..9b1faa9a 100644 --- a/session.go +++ b/session.go @@ -19,7 +19,7 @@ type Session interface { // GetExpiresAt returns the expiration time of a token if set, or time.IsZero() if not. // - // session.GetExpiresAt(oauth2.AccessToken) + // session.GetExpiresTimeX(oauth2.AccessToken) GetExpiresAt(key TokenType) time.Time // GetUsername returns the username, if set. This is optional and only used during token introspection. diff --git a/testing/mock/client.go b/testing/mock/client.go index 7c88ba6c..a665c669 100644 --- a/testing/mock/client.go +++ b/testing/mock/client.go @@ -42,7 +42,7 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { // GetAudience mocks base method. func (m *MockClient) GetAudience() oauth2.Arguments { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAudience") + ret := m.ctrl.Call(m, "GetAudienceX") ret0, _ := ret[0].(oauth2.Arguments) return ret0 } @@ -50,7 +50,7 @@ func (m *MockClient) GetAudience() oauth2.Arguments { // GetAudience indicates an expected call of GetAudience. func (mr *MockClientMockRecorder) GetAudience() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAudience", reflect.TypeOf((*MockClient)(nil).GetAudience)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAudienceX", reflect.TypeOf((*MockClient)(nil).GetAudience)) } // GetClientSecret mocks base method. diff --git a/token/hmac/hmacsha_test.go b/token/hmac/hmacsha_test.go index a064a2f1..63b8a0c8 100644 --- a/token/hmac/hmacsha_test.go +++ b/token/hmac/hmacsha_test.go @@ -65,7 +65,7 @@ func TestValidateSignatureRejects(t *testing.T) { cg := HMACStrategy{ Config: &oauth2.Config{GlobalSecret: []byte("1234567890123456789012345678901234567890")}, } - for k, c := range []string{ + for _, c := range []string{ "", " ", "foo.bar", @@ -74,7 +74,6 @@ func TestValidateSignatureRejects(t *testing.T) { } { err = cg.Validate(context.Background(), c) assert.Error(t, err) - t.Logf("Passed test case %d", k) } } diff --git a/token/jarm/generate.go b/token/jarm/generate.go index edd94ba5..23321b65 100644 --- a/token/jarm/generate.go +++ b/token/jarm/generate.go @@ -79,5 +79,5 @@ func Generate(ctx context.Context, config Configurator, client Client, session a return "", "", errors.New("The JARM response modes require the JWTSecuredAuthorizeResponseModeSignerProvider to return a jwt.Strategy but it didn't.") } - return signer.Encode(ctx, jwt.WithClaims(claims.ToMapClaims()), jwt.WithHeaders(&jwt.Headers{Extra: headers}), jwt.WithJARMClient(client)) + return signer.Encode(ctx, claims.ToMapClaims(), jwt.WithHeaders(&jwt.Headers{Extra: headers}), jwt.WithJARMClient(client)) } diff --git a/token/jwt/claims.go b/token/jwt/claims.go index 501e1ad3..36719a70 100644 --- a/token/jwt/claims.go +++ b/token/jwt/claims.go @@ -7,6 +7,17 @@ import ( "time" ) +type Claims interface { + GetExpirationTime() (exp *NumericDate, err error) + GetIssuedAt() (iat *NumericDate, err error) + GetNotBefore() (nbf *NumericDate, err error) + GetIssuer() (iss string, err error) + GetSubject() (sub string, err error) + GetAudience() (aud ClaimStrings, err error) + ToMapClaims() MapClaims + Valid(opts ...ClaimValidationOption) (err error) +} + // Mapper is the interface used internally to map key-value pairs type Mapper interface { ToMap() map[string]any diff --git a/token/jwt/claims_id_token.go b/token/jwt/claims_id_token.go index 255508bf..aa45f8fd 100644 --- a/token/jwt/claims_id_token.go +++ b/token/jwt/claims_id_token.go @@ -50,13 +50,13 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { ok = false switch claim { - case consts.ClaimJWTID: + case ClaimJWTID: c.JTI, ok = value.(string) - case consts.ClaimIssuer: + case ClaimIssuer: c.Issuer, ok = value.(string) - case consts.ClaimSubject: + case ClaimSubject: c.Subject, ok = value.(string) - case consts.ClaimAudience: + case ClaimAudience: switch aud := value.(type) { case string: ok = true @@ -81,21 +81,21 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { } } } - case consts.ClaimNonce: + case ClaimNonce: c.Nonce, ok = value.(string) - case consts.ClaimExpirationTime: + case ClaimExpirationTime: c.ExpiresAt, ok = toTime(value, c.ExpiresAt) - case consts.ClaimIssuedAt: + case ClaimIssuedAt: c.IssuedAt, ok = toTime(value, c.IssuedAt) - case consts.ClaimRequestedAt: + case ClaimRequestedAt: c.RequestedAt, ok = toTime(value, c.RequestedAt) - case consts.ClaimAuthenticationTime: + case ClaimAuthenticationTime: c.AuthTime, ok = toTime(value, c.AuthTime) - case consts.ClaimCodeHash: + case ClaimCodeHash: c.CodeHash, ok = value.(string) - case consts.ClaimStateHash: + case ClaimStateHash: c.StateHash, ok = value.(string) - case consts.ClaimAuthenticationContextClassReference: + case ClaimAuthenticationContextClassReference: c.AuthenticationContextClassReference, ok = value.(string) default: if c.Extra == nil { @@ -120,13 +120,13 @@ func (c *IDTokenClaims) ToMap() map[string]any { var ret = Copy(c.Extra) if c.Subject != "" { - ret[consts.ClaimSubject] = c.Subject + ret[ClaimSubject] = c.Subject } else { - delete(ret, consts.ClaimSubject) + delete(ret, ClaimSubject) } if c.Issuer != "" { - ret[consts.ClaimIssuer] = c.Issuer + ret[ClaimIssuer] = c.Issuer } else { delete(ret, consts.ClaimIssuer) } diff --git a/token/jwt/claims_id_token_test.go b/token/jwt/claims_id_token_test.go index 0fecf56f..66801dc4 100644 --- a/token/jwt/claims_id_token_test.go +++ b/token/jwt/claims_id_token_test.go @@ -19,7 +19,7 @@ func TestIDTokenAssert(t *testing.T) { assert.Error(t, (&IDTokenClaims{ExpiresAt: time.Now().UTC().Add(-time.Hour)}). ToMapClaims().Valid()) - assert.NotEmpty(t, (new(IDTokenClaims)).ToMapClaims()[consts.ClaimJWTID]) + assert.NotEmpty(t, (new(IDTokenClaims)).ToMapClaims()[ClaimJWTID]) } func TestIDTokenClaimsToMap(t *testing.T) { @@ -43,18 +43,18 @@ func TestIDTokenClaimsToMap(t *testing.T) { }, } assert.Equal(t, map[string]any{ - consts.ClaimJWTID: idTokenClaims.JTI, - consts.ClaimSubject: idTokenClaims.Subject, - consts.ClaimIssuedAt: idTokenClaims.IssuedAt.Unix(), - consts.ClaimIssuer: idTokenClaims.Issuer, - consts.ClaimAudience: idTokenClaims.Audience, - consts.ClaimExpirationTime: idTokenClaims.ExpiresAt.Unix(), - "foo": idTokenClaims.Extra["foo"], - "baz": idTokenClaims.Extra["baz"], - consts.ClaimAccessTokenHash: idTokenClaims.AccessTokenHash, - consts.ClaimCodeHash: idTokenClaims.CodeHash, - consts.ClaimStateHash: idTokenClaims.StateHash, - consts.ClaimAuthenticationTime: idTokenClaims.AuthTime.Unix(), + ClaimJWTID: idTokenClaims.JTI, + ClaimSubject: idTokenClaims.Subject, + ClaimIssuedAt: idTokenClaims.IssuedAt.Unix(), + ClaimIssuer: idTokenClaims.Issuer, + ClaimAudience: idTokenClaims.Audience, + ClaimExpirationTime: idTokenClaims.ExpiresAt.Unix(), + "foo": idTokenClaims.Extra["foo"], + "baz": idTokenClaims.Extra["baz"], + ClaimAccessTokenHash: idTokenClaims.AccessTokenHash, + ClaimCodeHash: idTokenClaims.CodeHash, + ClaimStateHash: idTokenClaims.StateHash, + ClaimAuthenticationTime: idTokenClaims.AuthTime.Unix(), consts.ClaimAuthenticationContextClassReference: idTokenClaims.AuthenticationContextClassReference, consts.ClaimAuthenticationMethodsReference: idTokenClaims.AuthenticationMethodsReferences, }, idTokenClaims.ToMap()) diff --git a/token/jwt/claims_jarm.go b/token/jwt/claims_jarm.go index 8557c19d..5328cf5e 100644 --- a/token/jwt/claims_jarm.go +++ b/token/jwt/claims_jarm.go @@ -23,33 +23,33 @@ func (c *JARMClaims) ToMap() map[string]any { var ret = Copy(c.Extra) if c.Issuer != "" { - ret[consts.ClaimIssuer] = c.Issuer + ret[ClaimIssuer] = c.Issuer } else { - delete(ret, consts.ClaimIssuer) + delete(ret, ClaimIssuer) } if c.JTI != "" { - ret[consts.ClaimJWTID] = c.JTI + ret[ClaimJWTID] = c.JTI } else { - ret[consts.ClaimJWTID] = uuid.New().String() + ret[ClaimJWTID] = uuid.New().String() } if len(c.Audience) > 0 { - ret[consts.ClaimAudience] = c.Audience + ret[ClaimAudience] = c.Audience } else { - ret[consts.ClaimAudience] = []string{} + ret[ClaimAudience] = []string{} } if !c.IssuedAt.IsZero() { - ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() + ret[ClaimIssuedAt] = c.IssuedAt.Unix() } else { - delete(ret, consts.ClaimIssuedAt) + delete(ret, ClaimIssuedAt) } if !c.ExpiresAt.IsZero() { - ret[consts.ClaimExpirationTime] = c.ExpiresAt.Unix() + ret[ClaimExpirationTime] = c.ExpiresAt.Unix() } else { - delete(ret, consts.ClaimExpirationTime) + delete(ret, ClaimExpirationTime) } return ret @@ -60,15 +60,15 @@ func (c *JARMClaims) FromMap(m map[string]any) { c.Extra = make(map[string]any) for k, v := range m { switch k { - case consts.ClaimIssuer: + case ClaimIssuer: if s, ok := v.(string); ok { c.Issuer = s } - case consts.ClaimJWTID: + case ClaimJWTID: if s, ok := v.(string); ok { c.JTI = s } - case consts.ClaimAudience: + case ClaimAudience: if aud, ok := StringSliceFromMap(v); ok { c.Audience = aud } diff --git a/token/jwt/claims_jarm_test.go b/token/jwt/claims_jarm_test.go index 941236d7..c35873f6 100644 --- a/token/jwt/claims_jarm_test.go +++ b/token/jwt/claims_jarm_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/assert" - "authelia.com/provider/oauth2/internal/consts" . "authelia.com/provider/oauth2/token/jwt" ) @@ -26,13 +25,13 @@ var jarmClaims = &JARMClaims{ } var jarmClaimsMap = map[string]any{ - consts.ClaimIssuer: jwtClaims.Issuer, - consts.ClaimAudience: jwtClaims.Audience, - consts.ClaimJWTID: jwtClaims.JTI, - consts.ClaimIssuedAt: jwtClaims.IssuedAt.Unix(), - consts.ClaimExpirationTime: jwtClaims.ExpiresAt.Unix(), - "foo": jwtClaims.Extra["foo"], - "baz": jwtClaims.Extra["baz"], + ClaimIssuer: jwtClaims.Issuer, + ClaimAudience: jwtClaims.Audience, + ClaimJWTID: jwtClaims.JTI, + ClaimIssuedAt: jwtClaims.IssuedAt.Unix(), + ClaimExpirationTime: jwtClaims.ExpiresAt.Unix(), + "foo": jwtClaims.Extra["foo"], + "baz": jwtClaims.Extra["baz"], } func TestJARMClaimAddGetString(t *testing.T) { @@ -41,7 +40,7 @@ func TestJARMClaimAddGetString(t *testing.T) { } func TestJARMClaimsToMapSetsID(t *testing.T) { - assert.NotEmpty(t, (&JARMClaims{}).ToMap()[consts.ClaimJWTID]) + assert.NotEmpty(t, (&JARMClaims{}).ToMap()[ClaimJWTID]) } func TestJARMAssert(t *testing.T) { diff --git a/token/jwt/claims_jwt.go b/token/jwt/claims_jwt.go index 7efdbb5f..ecaf5aea 100644 --- a/token/jwt/claims_jwt.go +++ b/token/jwt/claims_jwt.go @@ -101,58 +101,58 @@ func (c *JWTClaims) ToMap() map[string]any { var ret = Copy(c.Extra) if c.Subject != "" { - ret[consts.ClaimSubject] = c.Subject + ret[ClaimSubject] = c.Subject } else { - delete(ret, consts.ClaimSubject) + delete(ret, ClaimSubject) } if c.Issuer != "" { - ret[consts.ClaimIssuer] = c.Issuer + ret[ClaimIssuer] = c.Issuer } else { - delete(ret, consts.ClaimIssuer) + delete(ret, ClaimIssuer) } if c.JTI != "" { - ret[consts.ClaimJWTID] = c.JTI + ret[ClaimJWTID] = c.JTI } else { - ret[consts.ClaimJWTID] = uuid.New().String() + ret[ClaimJWTID] = uuid.New().String() } if len(c.Audience) > 0 { - ret[consts.ClaimAudience] = c.Audience + ret[ClaimAudience] = c.Audience } else { - ret[consts.ClaimAudience] = []string{} + ret[ClaimAudience] = []string{} } if !c.IssuedAt.IsZero() { - ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() + ret[ClaimIssuedAt] = c.IssuedAt.Unix() } else { - delete(ret, consts.ClaimIssuedAt) + delete(ret, ClaimIssuedAt) } if !c.NotBefore.IsZero() { - ret[consts.ClaimNotBefore] = c.NotBefore.Unix() + ret[ClaimNotBefore] = c.NotBefore.Unix() } else { - delete(ret, consts.ClaimNotBefore) + delete(ret, ClaimNotBefore) } if !c.ExpiresAt.IsZero() { - ret[consts.ClaimExpirationTime] = c.ExpiresAt.Unix() + ret[ClaimExpirationTime] = c.ExpiresAt.Unix() } else { - delete(ret, consts.ClaimExpirationTime) + delete(ret, ClaimExpirationTime) } if c.Scope != nil { // ScopeField default (when value is JWTScopeFieldUnset) is the list for backwards compatibility with old versions of oauth2. if c.ScopeField == JWTScopeFieldUnset || c.ScopeField == JWTScopeFieldList || c.ScopeField == JWTScopeFieldBoth { - ret[consts.ClaimScopeNonStandard] = c.Scope + ret[ClaimScopeNonStandard] = c.Scope } if c.ScopeField == JWTScopeFieldString || c.ScopeField == JWTScopeFieldBoth { - ret[consts.ClaimScope] = strings.Join(c.Scope, " ") + ret[ClaimScope] = strings.Join(c.Scope, " ") } } else { - delete(ret, consts.ClaimScopeNonStandard) - delete(ret, consts.ClaimScope) + delete(ret, ClaimScopeNonStandard) + delete(ret, ClaimScope) } return ret @@ -267,6 +267,41 @@ func toInt64(v any) (val int64, ok bool) { return 0, false } +func toNumericDate(v any) (date *NumericDate, err error) { + switch value := v.(type) { + case float64: + if value == 0 { + return nil, nil + } + + return newNumericDateFromSeconds(value), nil + case int64: + if value == 0 { + return nil, nil + } + + return newNumericDateFromSeconds(float64(value)), nil + case int32: + if value == 0 { + return nil, nil + } + + return newNumericDateFromSeconds(float64(value)), nil + case int: + if value == 0 { + return nil, nil + } + + return newNumericDateFromSeconds(float64(value)), nil + case json.Number: + vv, _ := value.Float64() + + return newNumericDateFromSeconds(vv), nil + } + + return nil, newError("value has invalid type", ErrInvalidType) +} + // Add will add a key-value pair to the extra field func (c *JWTClaims) Add(key string, value any) { if c.Extra == nil { diff --git a/token/jwt/claims_jwt_test.go b/token/jwt/claims_jwt_test.go index 3b612140..653ea8ef 100644 --- a/token/jwt/claims_jwt_test.go +++ b/token/jwt/claims_jwt_test.go @@ -30,16 +30,16 @@ var jwtClaims = &JWTClaims{ } var jwtClaimsMap = map[string]any{ - consts.ClaimSubject: jwtClaims.Subject, - consts.ClaimIssuedAt: jwtClaims.IssuedAt.Unix(), - consts.ClaimIssuer: jwtClaims.Issuer, - consts.ClaimNotBefore: jwtClaims.NotBefore.Unix(), - consts.ClaimAudience: jwtClaims.Audience, - consts.ClaimExpirationTime: jwtClaims.ExpiresAt.Unix(), - consts.ClaimJWTID: jwtClaims.JTI, - consts.ClaimScopeNonStandard: []string{consts.ScopeEmail, consts.ScopeOffline}, - "foo": jwtClaims.Extra["foo"], - "baz": jwtClaims.Extra["baz"], + ClaimSubject: jwtClaims.Subject, + ClaimIssuedAt: jwtClaims.IssuedAt.Unix(), + ClaimIssuer: jwtClaims.Issuer, + ClaimNotBefore: jwtClaims.NotBefore.Unix(), + ClaimAudience: jwtClaims.Audience, + ClaimExpirationTime: jwtClaims.ExpiresAt.Unix(), + ClaimJWTID: jwtClaims.JTI, + ClaimScopeNonStandard: []string{consts.ScopeEmail, consts.ScopeOffline}, + "foo": jwtClaims.Extra["foo"], + "baz": jwtClaims.Extra["baz"], } func TestClaimAddGetString(t *testing.T) { @@ -48,7 +48,7 @@ func TestClaimAddGetString(t *testing.T) { } func TestClaimsToMapSetsID(t *testing.T) { - assert.NotEmpty(t, (&JWTClaims{}).ToMap()[consts.ClaimJWTID]) + assert.NotEmpty(t, (&JWTClaims{}).ToMap()[ClaimJWTID]) } func TestAssert(t *testing.T) { @@ -78,8 +78,8 @@ func TestScopeFieldString(t *testing.T) { jwtClaimsWithString := jwtClaims.WithScopeField(JWTScopeFieldString) // Making a copy of jwtClaimsMap. jwtClaimsMapWithString := jwtClaims.ToMap() - delete(jwtClaimsMapWithString, consts.ClaimScopeNonStandard) - jwtClaimsMapWithString[consts.ClaimScope] = "email offline" + delete(jwtClaimsMapWithString, ClaimScopeNonStandard) + jwtClaimsMapWithString[ClaimScope] = "email offline" assert.Equal(t, jwtClaimsMapWithString, map[string]any(jwtClaimsWithString.ToMapClaims())) var claims JWTClaims claims.FromMap(jwtClaimsMapWithString) @@ -90,7 +90,7 @@ func TestScopeFieldBoth(t *testing.T) { jwtClaimsWithBoth := jwtClaims.WithScopeField(JWTScopeFieldBoth) // Making a copy of jwtClaimsMap jwtClaimsMapWithBoth := jwtClaims.ToMap() - jwtClaimsMapWithBoth[consts.ClaimScope] = "email offline" + jwtClaimsMapWithBoth[ClaimScope] = "email offline" assert.Equal(t, jwtClaimsMapWithBoth, map[string]any(jwtClaimsWithBoth.ToMapClaims())) var claims JWTClaims claims.FromMap(jwtClaimsMapWithBoth) diff --git a/token/jwt/claims_map.go b/token/jwt/claims_map.go index 844ee7de..4e1620b3 100644 --- a/token/jwt/claims_map.go +++ b/token/jwt/claims_map.go @@ -7,90 +7,83 @@ import ( "bytes" "crypto/subtle" "errors" + "fmt" "time" jjson "github.com/go-jose/go-jose/v4/json" - "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/x/errorsx" ) -var TimeFunc = time.Now - -// MapClaims provides backwards compatible validations not available in `go-jose`. -// It was taken from [here](https://raw.githubusercontent.com/form3tech-oss/jwt-go/master/map_claims.go). -// -// Claims type that uses the map[string]any for JSON decoding -// This is the default claims type if you don't supply one +// MapClaims is a simple map based claims structure. type MapClaims map[string]any -// GetIssuer returns the iss claim. -func (m MapClaims) GetIssuer() (iss string, ok bool) { - var v any - - if v, ok = m[consts.ClaimIssuer]; !ok { - return "", false - } - - iss, ok = v.(string) - - return iss, ok +// GetIssuer returns the 'iss' claim. +func (m MapClaims) GetIssuer() (iss string, err error) { + return m.toString(ClaimIssuer) } // VerifyIssuer compares the iss claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyIssuer(cmp string, required bool) (ok bool) { - var iss string + var ( + iss string + err error + ) - if iss, ok = m.GetIssuer(); !ok { - return !required + if iss, err = m.GetIssuer(); err != nil { + return false } - return verifyMapString(iss, cmp, required) -} - -// GetSubject returns the sub claim. -func (m MapClaims) GetSubject() (sub string, ok bool) { - var v any - - if v, ok = m[consts.ClaimSubject]; !ok { - return "", false + if iss == "" { + return !required } - sub, ok = v.(string) + return verifyString(iss, cmp, required) +} - return sub, ok +// GetSubject returns the 'sub' claim. +func (m MapClaims) GetSubject() (sub string, err error) { + return m.toString(ClaimSubject) } // VerifySubject compares the iss claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifySubject(cmp string, required bool) (ok bool) { - var sub string + var ( + sub string + err error + ) + + if sub, err = m.GetSubject(); err != nil { + return false + } - if sub, ok = m.GetSubject(); !ok { + if sub == "" { return !required } - return verifyMapString(sub, cmp, required) + return verifyString(sub, cmp, required) } -// GetAudience returns the aud claim. -func (m MapClaims) GetAudience() (aud []string, ok bool) { - var v any - - if v, ok = m[consts.ClaimAudience]; !ok { - return nil, false - } - - return StringSliceFromMap(v) +// GetAudience returns the 'aud' claim. +func (m MapClaims) GetAudience() (aud ClaimStrings, err error) { + return m.toClaimsString(ClaimAudience) } // VerifyAudience compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyAudience(cmp string, required bool) (ok bool) { - var aud []string + var ( + aud ClaimStrings + err error + ) - if aud, ok = m.GetAudience(); !ok { + if aud, err = m.GetAudience(); err != nil { + return false + } + + if aud == nil { return !required } @@ -101,9 +94,16 @@ func (m MapClaims) VerifyAudience(cmp string, required bool) (ok bool) { // If required is false, this method will return true if the value matches or is unset. // This variant requires all of the audience values in the cmp. func (m MapClaims) VerifyAudienceAll(cmp []string, required bool) (ok bool) { - var aud []string + var ( + aud ClaimStrings + err error + ) + + if aud, err = m.GetAudience(); err != nil { + return false + } - if aud, ok = m.GetAudience(); !ok { + if aud == nil { return !required } @@ -114,64 +114,104 @@ func (m MapClaims) VerifyAudienceAll(cmp []string, required bool) (ok bool) { // If required is false, this method will return true if the value matches or is unset. // This variant requires any of the audience values in the cmp. func (m MapClaims) VerifyAudienceAny(cmp []string, required bool) (ok bool) { - var aud []string + var ( + aud ClaimStrings + err error + ) + + if aud, err = m.GetAudience(); err != nil { + return false + } - if aud, ok = m.GetAudience(); !ok { + if aud == nil { return !required } return verifyAudAny(aud, cmp, required) } -// GetExpiresAt returns the exp claim. -func (m MapClaims) GetExpiresAt() (exp int64, ok bool) { - return m.toInt64(consts.ClaimExpirationTime) +// GetExpirationTime returns the 'exp' claim. +func (m MapClaims) GetExpirationTime() (exp *NumericDate, err error) { + return m.toNumericDate(ClaimExpirationTime) } -// VerifyExpiresAt compares the exp claim against cmp. +// VerifyExpirationTime compares the exp claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyExpiresAt(cmp int64, required bool) (ok bool) { - var exp int64 +func (m MapClaims) VerifyExpirationTime(cmp int64, required bool) (ok bool) { + var ( + exp *NumericDate + err error + ) + + if exp, err = m.GetExpirationTime(); err != nil { + return false + } - if exp, ok = m.GetExpiresAt(); !ok { + if exp == nil { return !required } - return verifyInt64Future(exp, cmp, required) + return verifyInt64Future(exp.Int64(), cmp, required) } -// GetIssuedAt returns the iat claim. -func (m MapClaims) GetIssuedAt() (iat int64, ok bool) { - return m.toInt64(consts.ClaimIssuedAt) +// GetIssuedAt returns the 'iat' claim. +func (m MapClaims) GetIssuedAt() (iat *NumericDate, err error) { + return m.toNumericDate(ClaimIssuedAt) } // VerifyIssuedAt compares the iat claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyIssuedAt(cmp int64, required bool) (ok bool) { - var iat int64 + var ( + iat *NumericDate + err error + ) - if iat, ok = m.GetIssuedAt(); !ok { + if iat, err = m.GetIssuedAt(); err != nil { + return false + } + + if iat == nil { return !required } - return verifyInt64Past(iat, cmp, required) + return verifyInt64Past(iat.Int64(), cmp, required) } -// GetNotBefore returns the nbf claim. -func (m MapClaims) GetNotBefore() (nbf int64, ok bool) { - return m.toInt64(consts.ClaimNotBefore) +// GetNotBefore returns the 'nbf' claim. +func (m MapClaims) GetNotBefore() (nbf *NumericDate, err error) { + return m.toNumericDate(ClaimNotBefore) } // VerifyNotBefore compares the nbf claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyNotBefore(cmp int64, required bool) (ok bool) { - var nbf int64 + var ( + nbf *NumericDate + err error + ) + + if nbf, err = m.GetNotBefore(); err != nil { + return false + } - if nbf, ok = m.GetNotBefore(); !ok { + if nbf == nil { return !required } - return verifyInt64Past(nbf, cmp, required) + return verifyInt64Past(nbf.Int64(), cmp, required) +} + +func (m MapClaims) ToMapClaims() MapClaims { + if m == nil { + return nil + } + + return m +} + +func (m MapClaims) ToMap() map[string]any { + return m } // Valid validates the given claims. By default it only validates time based claims "exp, iat, nbf"; there is no @@ -194,7 +234,7 @@ func (m MapClaims) Valid(opts ...ClaimValidationOption) (err error) { vErr := new(ValidationError) - if !m.VerifyExpiresAt(now, vopts.expRequired) { + if !m.VerifyExpirationTime(now, vopts.expRequired) { vErr.Inner = errors.New("Token is expired") vErr.Errors |= ValidationErrorExpired } @@ -267,6 +307,61 @@ func (m MapClaims) toInt64(claim string) (val int64, ok bool) { return toInt64(v) } +func (m MapClaims) toNumericDate(key string) (date *NumericDate, err error) { + var ( + v any + ok bool + ) + + if v, ok = m[key]; !ok { + return nil, nil + } + + return toNumericDate(v) +} + +func (m MapClaims) toString(key string) (value string, err error) { + var ( + ok bool + raw any + ) + + if raw, ok = m[key]; !ok { + return "", nil + } + + if value, ok = raw.(string); !ok { + return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) + } + + return value, nil +} + +func (m MapClaims) toClaimsString(key string) (ClaimStrings, error) { + var cs []string + + switch v := m[key].(type) { + case string: + cs = append(cs, v) + case []string: + cs = v + case []any: + for _, a := range v { + if vs, ok := a.(string); !ok { + return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) + } else { + cs = append(cs, vs) + } + } + case nil: + return nil, nil + default: + return cs, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) + } + + return cs, nil +} + type ClaimValidationOption func(opts *ClaimValidationOptions) type ClaimValidationOptions struct { @@ -395,7 +490,7 @@ func verifyInt64Past(value, now int64, required bool) bool { return now >= value } -func verifyMapString(value, cmp string, required bool) bool { +func verifyString(value, cmp string, required bool) bool { if value == "" { return !required } diff --git a/token/jwt/claims_map_test.go b/token/jwt/claims_map_test.go index d5fdb53d..ba16becf 100644 --- a/token/jwt/claims_map_test.go +++ b/token/jwt/claims_map_test.go @@ -25,7 +25,7 @@ func TestMapClaims_VerifyAudience(t *testing.T) { { "ShouldPass", MapClaims{ - consts.ClaimAudience: []string{"foo"}, + ClaimAudience: []string{"foo"}, }, "foo", true, @@ -34,7 +34,7 @@ func TestMapClaims_VerifyAudience(t *testing.T) { { "ShouldPassMultiple", MapClaims{ - consts.ClaimAudience: []string{"foo", "bar"}, + ClaimAudience: []string{"foo", "bar"}, }, "foo", true, @@ -50,7 +50,7 @@ func TestMapClaims_VerifyAudience(t *testing.T) { { "ShouldFailNoMatch", MapClaims{ - consts.ClaimAudience: []string{"bar"}, + ClaimAudience: []string{"bar"}, }, "foo", true, @@ -66,7 +66,7 @@ func TestMapClaims_VerifyAudience(t *testing.T) { { "ShouldPassTypeAny", MapClaims{ - consts.ClaimAudience: []any{"foo"}, + ClaimAudience: []any{"foo"}, }, "foo", true, @@ -75,7 +75,7 @@ func TestMapClaims_VerifyAudience(t *testing.T) { { "ShouldPassTypeString", MapClaims{ - consts.ClaimAudience: "foo", + ClaimAudience: "foo", }, "foo", true, @@ -453,15 +453,6 @@ func TestMapClaims_VerifyIssuer(t *testing.T) { true, false, }, - { - "ShouldPassNil", - MapClaims{ - consts.ClaimIssuer: nil, - }, - "foo", - false, - true, - }, } for _, tc := range testCases { @@ -529,15 +520,6 @@ func TestMapClaims_VerifySubject(t *testing.T) { true, false, }, - { - "ShouldPassNil", - MapClaims{ - consts.ClaimSubject: nil, - }, - "foo", - false, - true, - }, } for _, tc := range testCases { @@ -623,20 +605,11 @@ func TestMapClaims_VerifyExpiresAt(t *testing.T) { true, false, }, - { - "ShouldPassNil", - MapClaims{ - consts.ClaimExpirationTime: nil, - }, - int64(123), - false, - true, - }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.expected, tc.have.VerifyExpiresAt(tc.cmp, tc.required)) + assert.Equal(t, tc.expected, tc.have.VerifyExpirationTime(tc.cmp, tc.required)) }) } } @@ -726,15 +699,6 @@ func TestMapClaims_VerifyIssuedAt(t *testing.T) { true, false, }, - { - "ShouldPassNil", - MapClaims{ - consts.ClaimIssuedAt: nil, - }, - int64(123), - false, - true, - }, } for _, tc := range testCases { @@ -829,15 +793,6 @@ func TestMapClaims_VerifyNotBefore(t *testing.T) { true, false, }, - { - "ShouldPassNil", - MapClaims{ - consts.ClaimNotBefore: nil, - }, - int64(123), - false, - true, - }, } for _, tc := range testCases { diff --git a/token/jwt/consts.go b/token/jwt/consts.go index 19fb5470..c27c8ae8 100644 --- a/token/jwt/consts.go +++ b/token/jwt/consts.go @@ -7,7 +7,7 @@ import ( ) const ( - SigningMethodNone = jose.SignatureAlgorithm(consts.JSONWebTokenAlgNone) + SigningMethodNone = jose.SignatureAlgorithm(JSONWebTokenAlgNone) // UnsafeAllowNoneSignatureType is unsafe to use and should be use to correctly sign and verify alg:none JWT tokens. UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed" @@ -22,7 +22,7 @@ type Keyfunc func(token *Token) (key any, err error) var ( // SignatureAlgorithmsNone contain all algorithms including 'none'. - SignatureAlgorithmsNone = []jose.SignatureAlgorithm{consts.JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512} + SignatureAlgorithmsNone = []jose.SignatureAlgorithm{JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512} // SignatureAlgorithms contain all algorithms excluding 'none'. SignatureAlgorithms = []jose.SignatureAlgorithm{jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512} @@ -32,3 +32,64 @@ var ( ContentEncryptionAlgorithms = []jose.ContentEncryption{jose.A128CBC_HS256, jose.A192CBC_HS384, jose.A256CBC_HS512, jose.A128GCM, jose.A192GCM, jose.A256GCM} ) + +const ( + ClaimJWTID = consts.ClaimJWTID + ClaimSessionID = consts.ClaimSessionID + ClaimIssuedAt = consts.ClaimIssuedAt + ClaimNotBefore = consts.ClaimNotBefore + ClaimRequestedAt = consts.ClaimRequestedAt + ClaimExpirationTime = consts.ClaimExpirationTime + ClaimAuthenticationTime = consts.ClaimAuthenticationTime + ClaimIssuer = consts.ClaimIssuer + ClaimSubject = consts.ClaimSubject + ClaimAudience = consts.ClaimAudience + ClaimGroups = consts.ClaimGroups + ClaimFullName = consts.ClaimFullName + ClaimPreferredUsername = consts.ClaimPreferredUsername + ClaimPreferredEmail = consts.ClaimPreferredEmail + ClaimEmailVerified = consts.ClaimEmailVerified + ClaimAuthorizedParty = consts.ClaimAuthorizedParty + ClaimAuthenticationContextClassReference = consts.ClaimAuthenticationContextClassReference + ClaimAuthenticationMethodsReference = consts.ClaimAuthenticationMethodsReference + ClaimClientIdentifier = consts.ClaimClientIdentifier + ClaimScope = consts.ClaimScope + ClaimScopeNonStandard = consts.ClaimScopeNonStandard + ClaimExtra = consts.ClaimExtra + ClaimActive = consts.ClaimActive + ClaimUsername = consts.ClaimUsername + ClaimTokenIntrospection = consts.ClaimTokenIntrospection + ClaimAccessTokenHash = consts.ClaimAccessTokenHash + ClaimCodeHash = consts.ClaimCodeHash + ClaimStateHash = consts.ClaimStateHash + ClaimNonce = consts.ClaimNonce + ClaimAuthorizedActor = consts.ClaimAuthorizedActor + ClaimActor = consts.ClaimActor +) + +const ( + JSONWebTokenHeaderKeyIdentifier = consts.JSONWebTokenHeaderKeyIdentifier + JSONWebTokenHeaderAlgorithm = consts.JSONWebTokenHeaderAlgorithm + JSONWebTokenHeaderEncryptionAlgorithm = consts.JSONWebTokenHeaderEncryptionAlgorithm + JSONWebTokenHeaderCompressionAlgorithm = consts.JSONWebTokenHeaderCompressionAlgorithm + JSONWebTokenHeaderPBES2Count = consts.JSONWebTokenHeaderPBES2Count + + JSONWebTokenHeaderType = consts.JSONWebTokenHeaderType + JSONWebTokenHeaderContentType = consts.JSONWebTokenHeaderContentType +) + +const ( + JSONWebTokenUseSignature = consts.JSONWebTokenUseSignature + JSONWebTokenUseEncryption = consts.JSONWebTokenUseEncryption +) + +const ( + JSONWebTokenTypeJWT = consts.JSONWebTokenTypeJWT + JSONWebTokenTypeAccessToken = consts.JSONWebTokenTypeAccessToken + JSONWebTokenTypeAccessTokenAlternative = consts.JSONWebTokenTypeAccessTokenAlternative + JSONWebTokenTypeTokenIntrospection = consts.JSONWebTokenTypeTokenIntrospection +) + +const ( + JSONWebTokenAlgNone = consts.JSONWebTokenAlgNone +) diff --git a/token/jwt/date.go b/token/jwt/date.go new file mode 100644 index 00000000..187f5cbf --- /dev/null +++ b/token/jwt/date.go @@ -0,0 +1,115 @@ +package jwt + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "strconv" + "time" +) + +type NumericDate struct { + time.Time +} + +func NewNumericDate(t time.Time) *NumericDate { + return &NumericDate{t.Truncate(TimePrecision)} +} + +func newNumericDateFromSeconds(f float64) *NumericDate { + round, frac := math.Modf(f) + return NewNumericDate(time.Unix(int64(round), int64(frac*1e9))) +} + +func (date NumericDate) MarshalJSON() (b []byte, err error) { + var prec int + if TimePrecision < time.Second { + prec = int(math.Log10(float64(time.Second) / float64(TimePrecision))) + } + truncatedDate := date.Truncate(TimePrecision) + + seconds := strconv.FormatInt(truncatedDate.Unix(), 10) + nanosecondsOffset := strconv.FormatFloat(float64(truncatedDate.Nanosecond())/float64(time.Second), 'f', prec, 64) + + output := append([]byte(seconds), []byte(nanosecondsOffset)[1:]...) + + return output, nil +} + +func (date *NumericDate) UnmarshalJSON(b []byte) (err error) { + var ( + number json.Number + f float64 + ) + + if err = json.Unmarshal(b, &number); err != nil { + return fmt.Errorf("could not parse NumericData: %w", err) + } + + if f, err = number.Float64(); err != nil { + return fmt.Errorf("could not convert json number value to float: %w", err) + } + + n := newNumericDateFromSeconds(f) + *date = *n + + return nil +} + +// Int64 returns the time value with UTC as the location, truncated with TimePrecision; as a number of +// since the Unix epoch. +func (date *NumericDate) Int64() (val int64) { + if date == nil { + return 0 + } + + return date.UTC().Truncate(TimePrecision).Unix() +} + +type ClaimStrings []string + +func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) { + var value interface{} + + if err = json.Unmarshal(data, &value); err != nil { + return err + } + + var aud []string + + switch v := value.(type) { + case string: + aud = append(aud, v) + case []string: + aud = ClaimStrings(v) + case []interface{}: + for _, vv := range v { + vs, ok := vv.(string) + if !ok { + return ErrInvalidType + } + aud = append(aud, vs) + } + case nil: + return nil + default: + return ErrInvalidType + } + + *s = aud + + return +} + +func (s ClaimStrings) MarshalJSON() (b []byte, err error) { + if len(s) == 1 && !MarshalSingleStringAsArray { + return json.Marshal(s[0]) + } + + return json.Marshal([]string(s)) +} + +var ( + ErrInvalidType = errors.New("invalid type for claim") +) diff --git a/token/jwt/header.go b/token/jwt/header.go index 06c0571e..36550c13 100644 --- a/token/jwt/header.go +++ b/token/jwt/header.go @@ -3,10 +3,6 @@ package jwt -import ( - "authelia.com/provider/oauth2/internal/consts" -) - // Headers is the jwt headers type Headers struct { Extra map[string]any `json:"extra"` @@ -18,7 +14,7 @@ func NewHeaders() *Headers { // ToMap will transform the headers to a map structure func (h *Headers) ToMap() map[string]any { - var filter = map[string]bool{consts.JSONWebTokenHeaderAlgorithm: true} + var filter = map[string]bool{JSONWebTokenHeaderAlgorithm: true} var extra = map[string]any{} // filter known values from extra. diff --git a/token/jwt/issuer.go b/token/jwt/issuer.go index de778d40..d90f5595 100644 --- a/token/jwt/issuer.go +++ b/token/jwt/issuer.go @@ -8,8 +8,6 @@ import ( "fmt" "github.com/go-jose/go-jose/v4" - - "authelia.com/provider/oauth2/internal/consts" ) // NewDefaultIssuer returns a new issuer and verifies that one RS256 key exists. @@ -27,7 +25,7 @@ func NewDefaultIssuer(keys ...jose.JSONWebKey) (issuer *DefaultIssuer, err error continue } - if key.Use != consts.JSONWebTokenUseSignature { + if key.Use != JSONWebTokenUseSignature { continue } @@ -47,7 +45,7 @@ func NewDefaultIssuer(keys ...jose.JSONWebKey) (issuer *DefaultIssuer, err error func NewDefaultIssuerFromJWKS(jwks *jose.JSONWebKeySet) (issuer *DefaultIssuer, err error) { for _, key := range jwks.Keys { - if key.Use != consts.JSONWebTokenUseSignature { + if key.Use != JSONWebTokenUseSignature { continue } @@ -100,7 +98,7 @@ func NewDefaultIssuerRS256Unverified(key any) (issuer *DefaultIssuer) { Key: key, KeyID: "default", Algorithm: string(jose.RS256), - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, }, }, }, diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go index e507d08e..21c54ad9 100644 --- a/token/jwt/jwt_strategy.go +++ b/token/jwt/jwt_strategy.go @@ -7,7 +7,6 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" - "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/x/errorsx" ) @@ -15,7 +14,7 @@ import ( // specifically so it can be mocked and the opts values have very important semantics which are difficult to document. type Strategy interface { // Encode a JWT as either a JWS or JWE nested JWS. - Encode(ctx context.Context, opts ...StrategyOpt) (tokenString string, signature string, err error) + Encode(ctx context.Context, claims Claims, opts ...StrategyOpt) (tokenString string, signature string, err error) // Decrypt a JWT or if the provided JWT is a JWS just return it. Decrypt(ctx context.Context, tokenStringEnc string, opts ...StrategyOpt) (tokenString, signature string, jwe *jose.JSONWebEncryption, err error) @@ -47,9 +46,8 @@ type DefaultStrategy struct { Issuer Issuer } -func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (tokenString string, signature string, err error) { +func (j *DefaultStrategy) Encode(ctx context.Context, claims Claims, opts ...StrategyOpt) (tokenString string, signature string, err error) { o := &StrategyOpts{ - claims: MapClaims{}, headers: NewHeaders(), } @@ -64,21 +62,21 @@ func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (toke ) if o.client == nil { - if keySig, err = j.Issuer.GetIssuerJWK(ctx, "", string(jose.RS256), consts.JSONWebTokenUseSignature); err != nil { + if keySig, err = j.Issuer.GetIssuerJWK(ctx, "", string(jose.RS256), JSONWebTokenUseSignature); err != nil { return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving issuer jwk: %w", err)) } - } else if keySig, err = j.Issuer.GetIssuerJWK(ctx, o.client.GetSigningKeyID(), o.client.GetSigningAlg(), consts.JSONWebTokenUseSignature); err != nil { + } else if keySig, err = j.Issuer.GetIssuerJWK(ctx, o.client.GetSigningKeyID(), o.client.GetSigningAlg(), JSONWebTokenUseSignature); err != nil { return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving issuer jwk: %w", err)) } if o.client == nil { - return EncodeCompactSigned(ctx, o.claims, o.headers, keySig) + return EncodeCompactSigned(ctx, claims, o.headers, keySig) } kid, alg, enc := o.client.GetEncryptionKeyID(), o.client.GetEncryptionAlg(), o.client.GetEncryptionEnc() if len(kid) == 0 && len(alg) == 0 { - return EncodeCompactSigned(ctx, o.claims, o.headers, keySig) + return EncodeCompactSigned(ctx, claims, o.headers, keySig) } if len(enc) == 0 { @@ -88,14 +86,14 @@ func (j *DefaultStrategy) Encode(ctx context.Context, opts ...StrategyOpt) (toke var keyEnc *jose.JSONWebKey if IsEncryptedJWTClientSecretAlg(alg) { - if keyEnc, err = NewClientSecretJWKFromClient(ctx, o.client, kid, alg, enc, consts.JSONWebTokenUseEncryption); err != nil { + if keyEnc, err = NewClientSecretJWKFromClient(ctx, o.client, kid, alg, enc, JSONWebTokenUseEncryption); err != nil { return "", "", errorsx.WithStack(fmt.Errorf("Failed to encrypt the JWT using the client secret. %w", err)) } - } else if keyEnc, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseEncryption, false); err != nil { + } else if keyEnc, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, JSONWebTokenUseEncryption, false); err != nil { return "", "", errorsx.WithStack(fmt.Errorf("Failed to encrypt the JWT using the client configuration. %w", err)) } - return EncodeNestedCompactEncrypted(ctx, o.claims, o.headers, o.headersJWE, keySig, keyEnc, jose.ContentEncryption(enc)) + return EncodeNestedCompactEncrypted(ctx, claims, o.headers, o.headersJWE, keySig, keyEnc, jose.ContentEncryption(enc)) } func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, opts ...StrategyOpt) (tokenString, signature string, jwe *jose.JSONWebEncryption, err error) { @@ -144,10 +142,10 @@ func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, op return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } - if key, err = NewClientSecretJWKFromClient(ctx, o.client, kid, alg, enc, consts.JSONWebTokenUseEncryption); err != nil { + if key, err = NewClientSecretJWKFromClient(ctx, o.client, kid, alg, enc, JSONWebTokenUseEncryption); err != nil { return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } - } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseEncryption); err != nil { + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, JSONWebTokenUseEncryption); err != nil { return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } @@ -217,7 +215,7 @@ func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts . validate := o.client != nil || !o.allowUnverified - if alg != consts.JSONWebTokenAlgNone && validate { + if alg != JSONWebTokenAlgNone && validate { if err = j.validate(ctx, t, &claims, o); err != nil { return nil, errorsx.WithStack(err) } @@ -286,15 +284,15 @@ func (j *DefaultStrategy) validate(ctx context.Context, t *jwt.JSONWebToken, des return errorsx.WithStack(&ValidationError{Errors: ValidationErrorHeaderKeyIDInvalid, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not support tokens with a kid but the token has kid '%s'", alg, kid)}) } - if key, err = NewClientSecretJWKFromClient(ctx, o.client, "", alg, "", consts.JSONWebTokenUseSignature); err != nil { + if key, err = NewClientSecretJWKFromClient(ctx, o.client, "", alg, "", JSONWebTokenUseSignature); err != nil { return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } } else { - if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, consts.JSONWebTokenUseSignature, true); err != nil { + if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, JSONWebTokenUseSignature, true); err != nil { return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } } - } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, consts.JSONWebTokenUseSignature); err != nil { + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, JSONWebTokenUseSignature); err != nil { return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) } diff --git a/token/jwt/jwt_strategy_opts.go b/token/jwt/jwt_strategy_opts.go index 21520c35..e9a442c8 100644 --- a/token/jwt/jwt_strategy_opts.go +++ b/token/jwt/jwt_strategy_opts.go @@ -9,7 +9,6 @@ import ( type StrategyOpts struct { client Client - claims MapClaims headers, headersJWE Mapper @@ -53,14 +52,6 @@ func WithHeadersJWE(headers Mapper) StrategyOpt { } } -func WithClaims(claims MapClaims) StrategyOpt { - return func(opts *StrategyOpts) (err error) { - opts.claims = claims - - return nil - } -} - func WithClient(client Client) StrategyOpt { return func(opts *StrategyOpts) (err error) { opts.client = client diff --git a/token/jwt/jwt_strategy_test.go b/token/jwt/jwt_strategy_test.go index 02112b60..00738fa8 100644 --- a/token/jwt/jwt_strategy_test.go +++ b/token/jwt/jwt_strategy_test.go @@ -16,8 +16,6 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "authelia.com/provider/oauth2/internal/consts" ) func TestDefaultStrategy(t *testing.T) { @@ -45,19 +43,19 @@ func TestDefaultStrategy(t *testing.T) { { KeyID: "rs256-sig", Key: issuerRS256, - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.RS256), }, { KeyID: "es512-sig", Key: issuerES512, - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.ES512), }, { KeyID: "es512-enc", Key: issuerES512enc, - Use: consts.JSONWebTokenUseEncryption, + Use: JSONWebTokenUseEncryption, Algorithm: string(jose.ECDH_ES_A256KW), }, }, @@ -68,19 +66,19 @@ func TestDefaultStrategy(t *testing.T) { { KeyID: "rs256-sig", Key: &issuerRS256.PublicKey, - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.RS256), }, { KeyID: "es512-sig", Key: &issuerES512.PublicKey, - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.ES512), }, { KeyID: "es512-enc", Key: &issuerES512enc.PublicKey, - Use: consts.JSONWebTokenUseEncryption, + Use: JSONWebTokenUseEncryption, Algorithm: string(jose.ECDH_ES_A256KW), }, }, @@ -95,13 +93,13 @@ func TestDefaultStrategy(t *testing.T) { { KeyID: "es512-sig", Key: clientES512, - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.ES512), }, { KeyID: "es512-enc", Key: clientES512enc, - Use: consts.JSONWebTokenUseEncryption, + Use: JSONWebTokenUseEncryption, Algorithm: string(jose.ECDH_ES_A256KW), }, }, @@ -112,13 +110,13 @@ func TestDefaultStrategy(t *testing.T) { { KeyID: "es512-sig", Key: &clientES512.PublicKey, - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.ES512), }, { KeyID: "es512-enc", Key: &clientES512enc.PublicKey, - Use: consts.JSONWebTokenUseEncryption, + Use: JSONWebTokenUseEncryption, Algorithm: string(jose.ECDH_ES_A256KW), }, }, @@ -129,13 +127,13 @@ func TestDefaultStrategy(t *testing.T) { { KeyID: "es512-sig", Key: &issuerES512.PublicKey, - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.ES512), }, { KeyID: "es512-enc", Key: &issuerES512enc.PublicKey, - Use: consts.JSONWebTokenUseEncryption, + Use: JSONWebTokenUseEncryption, Algorithm: string(jose.ECDH_ES_A256KW), }, }, @@ -146,13 +144,13 @@ func TestDefaultStrategy(t *testing.T) { { KeyID: "es512-sig", Key: &clientES512.PublicKey, - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.ES512), }, { KeyID: "es512-enc", Key: &clientES512enc.PublicKey, - Use: consts.JSONWebTokenUseEncryption, + Use: JSONWebTokenUseEncryption, Algorithm: string(jose.ECDH_ES_A256KW), }, }, @@ -208,7 +206,7 @@ func TestDefaultStrategy(t *testing.T) { headers1 := &Headers{ Extra: map[string]any{ - consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeAccessToken, + JSONWebTokenHeaderType: JSONWebTokenTypeAccessToken, }, } @@ -218,7 +216,7 @@ func TestDefaultStrategy(t *testing.T) { token1, signature1 string ) - token1, signature1, err = strategy.Encode(ctx, WithClaims(claims), WithHeaders(headers1), WithClient(client)) + token1, signature1, err = strategy.Encode(ctx, claims, WithHeaders(headers1), WithClient(client)) require.NoError(t, err) assert.NotEmpty(t, signature1) @@ -232,11 +230,11 @@ func TestDefaultStrategy(t *testing.T) { headers2 := &Headers{ Extra: map[string]any{ - consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeJWT, + JSONWebTokenHeaderType: JSONWebTokenTypeJWT, }, } - token2, signature2, err = strategy.Encode(ctx, WithClaims(claims), WithHeaders(headers2), WithHeadersJWE(headersEnc), WithClient(clientEnc)) + token2, signature2, err = strategy.Encode(ctx, claims, WithHeaders(headers2), WithHeadersJWE(headersEnc), WithClient(clientEnc)) require.NoError(t, err) require.True(t, IsEncryptedJWT(token2)) require.NotEmpty(t, signature2) @@ -245,7 +243,7 @@ func TestDefaultStrategy(t *testing.T) { token3, signature3 string ) - token3, signature3, err = strategy.Encode(ctx, WithClaims(claims), WithHeaders(headers1), WithHeadersJWE(headersEnc), WithClient(clientEncAsymmetric)) + token3, signature3, err = strategy.Encode(ctx, claims, WithHeaders(headers1), WithHeadersJWE(headersEnc), WithClient(clientEncAsymmetric)) require.NoError(t, err) assert.NotEmpty(t, signature3) @@ -361,7 +359,7 @@ func TestNestedJWTEncodeDecode(t *testing.T) { }, } - tokenString, sig, err := providerStrategy.Encode(context.TODO(), WithClaims(claims), WithClient(encodeClientRSA)) + tokenString, sig, err := providerStrategy.Encode(context.TODO(), claims, WithClient(encodeClientRSA)) require.NoError(t, err) assert.NotEmpty(t, sig) assert.NotEmpty(t, tokenString) @@ -416,7 +414,7 @@ func TestNestedJWTEncodeDecode(t *testing.T) { }, } - tokenString, sig, err = providerStrategy.Encode(context.TODO(), WithClaims(claims), WithClient(encodeClientECDSA)) + tokenString, sig, err = providerStrategy.Encode(context.TODO(), claims, WithClient(encodeClientECDSA)) require.NoError(t, err) assert.NotEmpty(t, sig) assert.NotEmpty(t, tokenString) @@ -631,13 +629,13 @@ func init() { testKeySigRSA = jose.JSONWebKey{ Key: k, KeyID: "test-rsa-sig", - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.RS256), } testKeyPublicSigRSA = jose.JSONWebKey{ Key: k.Public(), KeyID: "test-rsa-sig", - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.RS256), } default: @@ -653,13 +651,13 @@ func init() { testKeyEncRSA = jose.JSONWebKey{ Key: k, KeyID: "test-rsa-enc", - Use: consts.JSONWebTokenUseEncryption, + Use: JSONWebTokenUseEncryption, Algorithm: string(jose.RSA_OAEP_256), } testKeyPublicEncRSA = jose.JSONWebKey{ Key: k.Public(), KeyID: "test-rsa-enc", - Use: consts.JSONWebTokenUseEncryption, + Use: JSONWebTokenUseEncryption, Algorithm: string(jose.RSA_OAEP_256), } default: @@ -675,13 +673,13 @@ func init() { testKeySigECDSA = jose.JSONWebKey{ Key: k, KeyID: "test-ecdsa-sig", - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.ES256), } testKeyPublicSigECDSA = jose.JSONWebKey{ Key: k.Public(), KeyID: "test-ecdsa-sig", - Use: consts.JSONWebTokenUseSignature, + Use: JSONWebTokenUseSignature, Algorithm: string(jose.ES256), } default: @@ -697,13 +695,13 @@ func init() { testKeyEncECDSA = jose.JSONWebKey{ Key: k, KeyID: "test-ecdsa-enc", - Use: consts.JSONWebTokenUseEncryption, + Use: JSONWebTokenUseEncryption, Algorithm: string(jose.ECDH_ES_A128KW), } testKeyPublicEncECDSA = jose.JSONWebKey{ Key: k.Public(), KeyID: "test-ecdsa-enc", - Use: consts.JSONWebTokenUseEncryption, + Use: JSONWebTokenUseEncryption, Algorithm: string(jose.ECDH_ES_A128KW), } default: diff --git a/token/jwt/token.go b/token/jwt/token.go index 39c01e7f..b85e72fe 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -57,7 +57,7 @@ func ParseCustomWithClaims(tokenString string, claims MapClaims, keyFunc Keyfunc var parsed *jwt.JSONWebToken if parsed, err = jwt.ParseSigned(tokenString, algs); err != nil { - return &Token{}, &ValidationError{Errors: ValidationErrorMalformed, Inner: err} + return &Token{Claims: MapClaims(nil)}, &ValidationError{Errors: ValidationErrorMalformed, Inner: err} } // fill unverified claims @@ -68,12 +68,12 @@ func ParseCustomWithClaims(tokenString string, claims MapClaims, keyFunc Keyfunc // Token, that is an unverified token, therefore an UnsafeClaimsWithoutVerification is done first // then with the returned key, the claims gets verified. if err = parsed.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} + return &Token{Claims: MapClaims(nil)}, &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} } // creates an unsafe token if token, err = newToken(parsed, claims); err != nil { - return nil, err + return &Token{Claims: MapClaims(nil)}, err } if keyFunc == nil { @@ -144,7 +144,7 @@ type Token struct { Header map[string]any HeaderJWE map[string]any - Claims MapClaims + Claims Claims parsedToken *jwt.JSONWebToken @@ -162,13 +162,13 @@ func (t *Token) IsSignatureValid() bool { // // > For a type to be a Claims object, it must just have a Valid method that determines // if the token is invalid for any supported reason -type Claims interface { - Valid() error -} +// type Claims interface { +// Valid() error +//} func (t *Token) toSignedJoseHeader() (header map[jose.HeaderKey]any) { header = map[jose.HeaderKey]any{ - consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeJWT, + JSONWebTokenHeaderType: JSONWebTokenTypeJWT, } for k, v := range t.Header { @@ -180,11 +180,11 @@ func (t *Token) toSignedJoseHeader() (header map[jose.HeaderKey]any) { func (t *Token) toEncryptedJoseHeader() (header map[jose.HeaderKey]any) { header = map[jose.HeaderKey]any{ - consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeJWT, + JSONWebTokenHeaderType: JSONWebTokenTypeJWT, } - if cty, ok := t.Header[consts.JSONWebTokenHeaderType]; ok { - header[consts.JSONWebTokenHeaderContentType] = cty + if cty, ok := t.Header[JSONWebTokenHeaderType]; ok { + header[JSONWebTokenHeaderContentType] = cty } for k, v := range t.HeaderJWE { @@ -195,7 +195,7 @@ func (t *Token) toEncryptedJoseHeader() (header map[jose.HeaderKey]any) { } // SetJWS sets the JWS output values. -func (t *Token) SetJWS(header Mapper, claims MapClaims, kid string, alg jose.SignatureAlgorithm) { +func (t *Token) SetJWS(header Mapper, claims Claims, kid string, alg jose.SignatureAlgorithm) { assign(t.Header, header.ToMap()) t.KeyID = kid @@ -221,11 +221,11 @@ func (t *Token) AssignJWE(jwe *jose.JSONWebEncryption) { } t.HeaderJWE = map[string]any{ - consts.JSONWebTokenHeaderAlgorithm: jwe.Header.Algorithm, + JSONWebTokenHeaderAlgorithm: jwe.Header.Algorithm, } if jwe.Header.KeyID != "" { - t.HeaderJWE[consts.JSONWebTokenHeaderKeyIdentifier] = jwe.Header.KeyID + t.HeaderJWE[JSONWebTokenHeaderKeyIdentifier] = jwe.Header.KeyID t.EncryptionKeyID = jwe.Header.KeyID } @@ -235,11 +235,11 @@ func (t *Token) AssignJWE(jwe *jose.JSONWebEncryption) { t.HeaderJWE[h] = value switch h { - case consts.JSONWebTokenHeaderEncryptionAlgorithm: + case JSONWebTokenHeaderEncryptionAlgorithm: if v, ok := value.(string); ok { t.ContentEncryption = jose.ContentEncryption(v) } - case consts.JSONWebTokenHeaderCompressionAlgorithm: + case JSONWebTokenHeaderCompressionAlgorithm: if v, ok := value.(string); ok { t.CompressionAlgorithm = jose.CompressionAlgorithm(v) } @@ -270,13 +270,13 @@ func (t *Token) CompactEncrypted(keySig, keyEnc any) (tokenString, signature str ExtraHeaders: t.toEncryptedJoseHeader(), } - if _, ok := opts.ExtraHeaders[consts.JSONWebTokenHeaderContentType]; !ok { + if _, ok := opts.ExtraHeaders[JSONWebTokenHeaderContentType]; !ok { var typ any - if typ, ok = t.Header[consts.JSONWebTokenHeaderType]; ok { - opts.ExtraHeaders[consts.JSONWebTokenHeaderContentType] = typ + if typ, ok = t.Header[JSONWebTokenHeaderType]; ok { + opts.ExtraHeaders[JSONWebTokenHeaderContentType] = typ } else { - opts.ExtraHeaders[consts.JSONWebTokenHeaderContentType] = consts.JSONWebTokenTypeJWT + opts.ExtraHeaders[JSONWebTokenHeaderContentType] = JSONWebTokenTypeJWT } } @@ -338,9 +338,9 @@ func (t *Token) CompactSignedString(k any) (tokenString string, err error) { // to map[string]any is required because the // go-jose CompactSerialize() only support explicit maps // as claims or structs but not type aliases from maps. - claims := map[string]any(t.Claims) + // claims := t.Claims.ToMapClaims() - if tokenString, err = jwt.Signed(signer).Claims(claims).Serialize(); err != nil { + if tokenString, err = jwt.Signed(signer).Claims(t.Claims.ToMapClaims().ToMap()).Serialize(); err != nil { return "", &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} } @@ -350,7 +350,7 @@ func (t *Token) CompactSignedString(k any) (tokenString string, err error) { // Valid validates the token headers given various input options. This does not validate any claims. func (t *Token) Valid(opts ...HeaderValidationOption) (err error) { vopts := &HeaderValidationOptions{ - types: []string{consts.JSONWebTokenTypeJWT}, + types: []string{JSONWebTokenTypeJWT}, } for _, opt := range opts { @@ -370,13 +370,13 @@ func (t *Token) Valid(opts ...HeaderValidationOption) (err error) { ok bool ) - if typ, ok = t.HeaderJWE[consts.JSONWebTokenHeaderType]; !ok || typ != consts.JSONWebTokenTypeJWT { + if typ, ok = t.HeaderJWE[JSONWebTokenHeaderType]; !ok || typ != JSONWebTokenTypeJWT { vErr.Inner = errors.New("token was encrypted with invalid typ") vErr.Errors |= ValidationErrorHeaderEncryptionTypeInvalid } - ttyp := t.Header[consts.JSONWebTokenHeaderType] - cty := t.HeaderJWE[consts.JSONWebTokenHeaderContentType] + ttyp := t.Header[JSONWebTokenHeaderType] + cty := t.HeaderJWE[JSONWebTokenHeaderContentType] if cty != ttyp { vErr.Inner = errors.New("token was encrypted with a cty value that doesn't match the typ value") @@ -448,26 +448,26 @@ func (t *Token) IsJWTProfileAccessToken() (ok bool) { ) if t.HeaderJWE != nil && len(t.HeaderJWE) > 0 { - if raw, ok = t.HeaderJWE[consts.JSONWebTokenHeaderContentType]; ok { + if raw, ok = t.HeaderJWE[JSONWebTokenHeaderContentType]; ok { cty, ok = raw.(string) if !ok { return false } - if cty != consts.JSONWebTokenTypeAccessToken && cty != consts.JSONWebTokenTypeAccessTokenAlternative { + if cty != JSONWebTokenTypeAccessToken && cty != JSONWebTokenTypeAccessTokenAlternative { return false } } } - if raw, ok = t.Header[consts.JSONWebTokenHeaderType]; !ok { + if raw, ok = t.Header[JSONWebTokenHeaderType]; !ok { return false } typ, ok = raw.(string) - return ok && (typ == consts.JSONWebTokenTypeAccessToken || typ == consts.JSONWebTokenTypeAccessTokenAlternative) + return ok && (typ == JSONWebTokenTypeAccessToken || typ == JSONWebTokenTypeAccessTokenAlternative) } type HeaderValidationOption func(opts *HeaderValidationOptions) @@ -518,10 +518,10 @@ func ValidateContentEncryption(enc string) HeaderValidationOption { } func unsignedToken(token *Token) (tokenString string, err error) { - token.Header[consts.JSONWebTokenHeaderAlgorithm] = consts.JSONWebTokenAlgNone + token.Header[JSONWebTokenHeaderAlgorithm] = JSONWebTokenAlgNone - if _, ok := token.Header[consts.JSONWebTokenHeaderType]; !ok { - token.Header[consts.JSONWebTokenHeaderType] = consts.JSONWebTokenTypeJWT + if _, ok := token.Header[JSONWebTokenHeaderType]; !ok { + token.Header[JSONWebTokenHeaderType] = JSONWebTokenTypeJWT } var ( @@ -541,6 +541,11 @@ func unsignedToken(token *Token) (tokenString string, err error) { func newToken(parsedToken *jwt.JSONWebToken, claims MapClaims) (*Token, error) { token := &Token{Claims: claims, parsedToken: parsedToken} + + if token.Claims == nil { + token.Claims = MapClaims{} + } + if len(parsedToken.Headers) != 1 { return nil, &ValidationError{text: fmt.Sprintf("only one header supported, got %v", len(parsedToken.Headers)), Errors: ValidationErrorMalformed} } @@ -548,7 +553,7 @@ func newToken(parsedToken *jwt.JSONWebToken, claims MapClaims) (*Token, error) { // copy headers h := parsedToken.Headers[0] token.Header = map[string]any{ - consts.JSONWebTokenHeaderAlgorithm: h.Algorithm, + JSONWebTokenHeaderAlgorithm: h.Algorithm, } token.SignatureAlgorithm = jose.SignatureAlgorithm(h.Algorithm) diff --git a/token/jwt/token_test.go b/token/jwt/token_test.go index 00cb9d89..170598dd 100644 --- a/token/jwt/token_test.go +++ b/token/jwt/token_test.go @@ -18,7 +18,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/internal/gen" ) @@ -55,10 +54,10 @@ func TestUnsignedToken(t *testing.T) { parts := strings.Split(rawToken, ".") require.Len(t, parts, 3) require.Empty(t, parts[2]) - tk, err := jwt.ParseSigned(rawToken, []jose.SignatureAlgorithm{consts.JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512}) + tk, err := jwt.ParseSigned(rawToken, []jose.SignatureAlgorithm{JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512}) require.NoError(t, err) require.Len(t, tk.Headers, 1) - require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[(consts.JSONWebTokenHeaderType)]) + require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[(JSONWebTokenHeaderType)]) }) } } @@ -72,12 +71,12 @@ func TestJWTHeaders(t *testing.T) { { name: "set JWT as 'typ' when the the type is not specified in the headers", jwtHeaders: map[string]any{}, - expectedType: consts.JSONWebTokenTypeJWT, + expectedType: JSONWebTokenTypeJWT, }, { name: "'typ' set explicitly", - jwtHeaders: map[string]any{consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeAccessToken}, - expectedType: consts.JSONWebTokenTypeAccessToken, + jwtHeaders: map[string]any{JSONWebTokenHeaderType: JSONWebTokenTypeAccessToken}, + expectedType: JSONWebTokenTypeAccessToken, }, } for _, tc := range testCases { @@ -87,7 +86,7 @@ func TestJWTHeaders(t *testing.T) { require.NoError(t, err) require.Len(t, tk.Headers, 1) require.Equal(t, tk.Headers[0].Algorithm, "RS256") - require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[(consts.JSONWebTokenHeaderType)]) + require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[(JSONWebTokenHeaderType)]) }) } } @@ -322,12 +321,12 @@ func TestParser_Parse(t *testing.T) { given: given{ name: "used before issued", generate: &generate{ - claims: MapClaims{"foo": "bar", consts.ClaimIssuedAt: time.Now().Unix() + 500}, + claims: MapClaims{"foo": "bar", ClaimIssuedAt: time.Now().Unix() + 500}, }, }, expected: expected{ keyFunc: defaultKeyFunc, - claims: MapClaims{"foo": "bar", consts.ClaimIssuedAt: time.Now().Unix() + 500}, + claims: MapClaims{"foo": "bar", ClaimIssuedAt: time.Now().Unix() + 500}, valid: false, errors: ValidationErrorIssuedAt, }, @@ -419,7 +418,7 @@ func TestParser_Parse(t *testing.T) { // Figure out correct claims type token, err = ParseWithClaims(data.tokenString, MapClaims{}, data.keyFunc) // Verify result matches expectation - assert.EqualValues(t, data.claims, token.Claims) + assert.EqualValues(t, data.claims, token.Claims.ToMapClaims()) if data.valid && err != nil { t.Errorf("[%v] Error while verifying token: %T:%v", data.name, err, err) } diff --git a/token/jwt/util.go b/token/jwt/util.go index 79f94475..f97ec97c 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -14,8 +14,6 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" "github.com/pkg/errors" - - "authelia.com/provider/oauth2/internal/consts" ) var ( @@ -108,7 +106,7 @@ func headerValidateJWE(header jose.Header) (kid, alg, enc, cty string, err error ) if IsEncryptedJWTPasswordBasedAlg(jose.KeyAlgorithm(header.Algorithm)) { - if value, ok = header.ExtraHeaders[consts.JSONWebTokenHeaderPBES2Count]; ok { + if value, ok = header.ExtraHeaders[JSONWebTokenHeaderPBES2Count]; ok { switch p2c := value.(type) { case float64: if p2c > 5000000 { @@ -122,7 +120,7 @@ func headerValidateJWE(header jose.Header) (kid, alg, enc, cty string, err error } } - if value, ok = header.ExtraHeaders[consts.JSONWebTokenHeaderEncryptionAlgorithm]; ok { + if value, ok = header.ExtraHeaders[JSONWebTokenHeaderEncryptionAlgorithm]; ok { switch encv := value.(type) { case string: if encv != "" { @@ -137,7 +135,7 @@ func headerValidateJWE(header jose.Header) (kid, alg, enc, cty string, err error } } - if value, ok = header.ExtraHeaders[consts.JSONWebTokenHeaderContentType]; ok { + if value, ok = header.ExtraHeaders[JSONWebTokenHeaderContentType]; ok { cty, _ = value.(string) } @@ -277,14 +275,14 @@ func NewClientSecretJWK(ctx context.Context, secret []byte, kid, alg, enc, use s } switch use { - case consts.JSONWebTokenUseSignature: + case JSONWebTokenUseSignature: return &jose.JSONWebKey{ Key: secret, KeyID: kid, Algorithm: alg, Use: use, }, nil - case consts.JSONWebTokenUseEncryption: + case JSONWebTokenUseEncryption: var ( hasher hash.Hash bits int @@ -344,7 +342,7 @@ func NewClientSecretJWK(ctx context.Context, secret []byte, kid, alg, enc, use s } // EncodeCompactSigned helps encoding a token using a signature backed compact encoding. -func EncodeCompactSigned(ctx context.Context, claims MapClaims, headers Mapper, key *jose.JSONWebKey) (tokenString string, signature string, err error) { +func EncodeCompactSigned(ctx context.Context, claims Claims, headers Mapper, key *jose.JSONWebKey) (tokenString string, signature string, err error) { token := New() if headers == nil { @@ -358,7 +356,7 @@ func EncodeCompactSigned(ctx context.Context, claims MapClaims, headers Mapper, // EncodeNestedCompactEncrypted helps encoding a token using a signature backed compact encoding, then nests that within // an encrypted compact encoded JWT. -func EncodeNestedCompactEncrypted(ctx context.Context, claims MapClaims, headers, headersJWE Mapper, keySig, keyEnc *jose.JSONWebKey, enc jose.ContentEncryption) (tokenString string, signature string, err error) { +func EncodeNestedCompactEncrypted(ctx context.Context, claims Claims, headers, headersJWE Mapper, keySig, keyEnc *jose.JSONWebKey, enc jose.ContentEncryption) (tokenString string, signature string, err error) { token := New() if headers == nil { @@ -430,3 +428,23 @@ func UnsafeParseSignedAny(tokenString string, dest any) (token *jwt.JSONWebToken return token, nil } + +func newError(message string, err error, more ...error) error { + var format string + var args []any + if message != "" { + format = "%w: %s" + args = []any{err, message} + } else { + format = "%w" + args = []any{err} + } + + for _, e := range more { + format += ": %w" + args = append(args, e) + } + + err = fmt.Errorf(format, args...) + return err +} diff --git a/token/jwt/variables.go b/token/jwt/variables.go new file mode 100644 index 00000000..d5421afc --- /dev/null +++ b/token/jwt/variables.go @@ -0,0 +1,9 @@ +package jwt + +import "time" + +var ( + MarshalSingleStringAsArray = true + TimePrecision = time.Second + TimeFunc = time.Now +) From 7fac4548479b6b77dd5211dda5b6018cae288d0d Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sun, 29 Sep 2024 11:34:55 +1000 Subject: [PATCH 30/33] feat: claims interface --- README.md | 9 +- handler/oauth2/flow_authorize_code_token.go | 5 +- .../oauth2/flow_authorize_code_token_test.go | 5 +- handler/oauth2/flow_authorize_implicit.go | 3 +- handler/oauth2/flow_generic_code_token.go | 5 +- handler/oauth2/flow_refresh.go | 5 +- handler/oauth2/flow_refresh_test.go | 35 +-- handler/oauth2/flow_resource_owner.go | 5 +- handler/oauth2/flow_resource_owner_test.go | 5 +- handler/oauth2/strategy_jwt_profile_test.go | 10 +- .../openid/flow_device_authorization_test.go | 4 +- handler/openid/flow_hybrid.go | 4 +- handler/openid/flow_hybrid_test.go | 2 +- handler/openid/flow_refresh_token.go | 5 +- handler/openid/strategy_jwt.go | 34 +-- handler/openid/strategy_jwt_test.go | 50 ++--- handler/openid/validator.go | 18 +- handler/openid/validator_test.go | 72 +++---- handler/rfc8628/device_authorize_handler.go | 3 +- .../rfc8628/token_endpoint_handler_test.go | 9 +- .../rfc8628/user_authorize_handler_test.go | 9 +- handler/rfc8693/access_token_type_handler.go | 3 +- handler/rfc8693/custom_jwt_type_handler.go | 8 +- handler/rfc8693/refresh_token_type_handler.go | 3 +- integration/oidc_explicit_test.go | 24 +-- internal/test_helpers.go | 4 +- introspection_response_writer_test.go | 2 +- rfc8628_device_authorize_write_test.go | 5 +- testing/mock/client.go | 4 +- token/jarm/generate.go | 30 +-- token/jwt/claims_id_token.go | 199 ++++++++++++++++-- token/jwt/claims_id_token_test.go | 16 +- token/jwt/claims_jarm.go | 118 +++++++++-- token/jwt/claims_jarm_test.go | 14 +- token/jwt/claims_jwt.go | 4 +- token/jwt/claims_jwt_test.go | 6 +- token/jwt/claims_map.go | 148 +------------ token/jwt/claims_test.go | 2 +- token/jwt/date.go | 61 +++++- token/jwt/token.go | 7 +- token/jwt/token_test.go | 1 - token/jwt/validate.go | 164 +++++++++++++++ 42 files changed, 723 insertions(+), 397 deletions(-) create mode 100644 token/jwt/validate.go diff --git a/README.md b/README.md index 15ce90d2..a623cdd5 100644 --- a/README.md +++ b/README.md @@ -22,9 +22,9 @@ following list of differences: - [x] Minimum dependency is go version 1.21 - [x] Replace string values with constants where applicable [commit](https://github.com/authelia/oauth2-provider/commit/de536dc0c9cd5f080c387621799e644319587bd0) -- [ ] Simplify the internal JWT logic to leverage `github.com/golang-jwt/jwt/v5` +- [x] Simplify the internal JWT logic to leverage `github.com/golang-jwt/jwt/v5` or other such libraries -- [ ] Implement internal JWKS logic +- [x] Implement internal JWKS logic - [x] Higher Debug error information visibility (Debug Field includes the complete RFC6749 error with debug information if available) - Fixes: @@ -103,10 +103,10 @@ following list of differences: - [x] General Refactor - [x] Prevent Multiple Client Authentication Methods - [x] Client Secret Validation Interface - - [ ] JWE support for Client Authentication and Issuance + - [x] JWE support for Client Authentication and Issuance - [x] Testing Package (mocks, etc) - [ ] Clock Drift Support - - [ ] Key Management + - [x] Key Management - [ ] Injectable Clock Configurator - [x] Support `s_hash` [commit](https://github.com/authelia/oauth2-provider/commit/edbbbe9467c70a2578db4b9af4d6cd319f74886e) @@ -125,6 +125,7 @@ following list of differences: - [x] `github.com/gobuffalo/packr` - [x] `github.com/form3tech-oss/jwt-go` - [x] `github.com/dgrijalva/jwt-go` + - [x] `github.com/golang-jwt/jwt` - Migration of the following dependencies: - [x] `github.com/go-jose/go-jose/v3` => `github.com/go-jose/go-jose/v4` - [x] `github.com/golang/mock` => `github.com/uber-go/mock` diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index 4ad81764..5976bbec 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -12,6 +12,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -106,11 +107,11 @@ func (c *AuthorizeExplicitGrantHandler) HandleTokenEndpointRequest(ctx context.C request.SetID(authorizeRequest.GetID()) atLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeAuthorizationCode, oauth2.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) - request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Truncate(jwt.TimePrecision)) rtLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeAuthorizationCode, oauth2.RefreshToken, c.Config.GetRefreshTokenLifespan(ctx)) if rtLifespan > -1 { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Truncate(jwt.TimePrecision)) } return nil diff --git a/handler/oauth2/flow_authorize_code_token_test.go b/handler/oauth2/flow_authorize_code_token_test.go index 4e1810a2..2e583852 100644 --- a/handler/oauth2/flow_authorize_code_token_test.go +++ b/handler/oauth2/flow_authorize_code_token_test.go @@ -19,6 +19,7 @@ import ( "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" "authelia.com/provider/oauth2/testing/mock" + "authelia.com/provider/oauth2/token/jwt" ) func TestAuthorizeCode_PopulateTokenEndpointResponse(t *testing.T) { @@ -462,8 +463,8 @@ func TestAuthorizeExplicitGrantHandler_HandleTokenEndpointRequest(t *testing.T) require.NoError(t, s.InvalidateAuthorizeCodeSession(context.TODO(), sig)) }, func(t *testing.T, s CoreStorage, r *oauth2.AccessRequest, ar *oauth2.AuthorizeRequest) { - assert.Equal(t, time.Now().Add(time.Minute).UTC().Round(time.Second), r.GetSession().GetExpiresAt(oauth2.AccessToken)) - assert.Equal(t, time.Now().Add(time.Minute).UTC().Round(time.Second), r.GetSession().GetExpiresAt(oauth2.RefreshToken)) + assert.Equal(t, time.Now().Add(time.Minute).UTC().Truncate(jwt.TimePrecision), r.GetSession().GetExpiresAt(oauth2.AccessToken)) + assert.Equal(t, time.Now().Add(time.Minute).UTC().Truncate(jwt.TimePrecision), r.GetSession().GetExpiresAt(oauth2.RefreshToken)) }, "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client. The authorization code has already been used.", }, diff --git a/handler/oauth2/flow_authorize_implicit.go b/handler/oauth2/flow_authorize_implicit.go index 00b05ece..2fa9e31d 100644 --- a/handler/oauth2/flow_authorize_implicit.go +++ b/handler/oauth2/flow_authorize_implicit.go @@ -11,6 +11,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -70,7 +71,7 @@ func (c *AuthorizeImplicitGrantTypeHandler) IssueImplicitAccessToken(ctx context // Only override expiry if none is set. atLifespan := oauth2.GetEffectiveLifespan(requester.GetClient(), oauth2.GrantTypeImplicit, oauth2.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) if requester.GetSession().GetExpiresAt(oauth2.AccessToken).IsZero() { - requester.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) + requester.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Truncate(jwt.TimePrecision)) } // Generate the code diff --git a/handler/oauth2/flow_generic_code_token.go b/handler/oauth2/flow_generic_code_token.go index f94d5d00..9c4ac293 100644 --- a/handler/oauth2/flow_generic_code_token.go +++ b/handler/oauth2/flow_generic_code_token.go @@ -8,6 +8,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -130,11 +131,11 @@ func (c *GenericCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context request.SetID(authorizeRequest.GetID()) atLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeAuthorizationCode, oauth2.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) - request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Truncate(jwt.TimePrecision)) rtLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeAuthorizationCode, oauth2.RefreshToken, c.Config.GetRefreshTokenLifespan(ctx)) if rtLifespan > -1 { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Truncate(jwt.TimePrecision)) } return nil diff --git a/handler/oauth2/flow_refresh.go b/handler/oauth2/flow_refresh.go index c0d0a6f5..fc121142 100644 --- a/handler/oauth2/flow_refresh.go +++ b/handler/oauth2/flow_refresh.go @@ -13,6 +13,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -146,11 +147,11 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex } atLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeRefreshToken, oauth2.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) - request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Truncate(jwt.TimePrecision)) rtLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeRefreshToken, oauth2.RefreshToken, c.Config.GetRefreshTokenLifespan(ctx)) if rtLifespan > -1 { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Truncate(jwt.TimePrecision)) } return nil diff --git a/handler/oauth2/flow_refresh_test.go b/handler/oauth2/flow_refresh_test.go index 230ce7a2..7f8f290e 100644 --- a/handler/oauth2/flow_refresh_test.go +++ b/handler/oauth2/flow_refresh_test.go @@ -20,6 +20,7 @@ import ( "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" "authelia.com/provider/oauth2/testing/mock" + "authelia.com/provider/oauth2/token/jwt" ) func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { @@ -114,7 +115,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", consts.ScopeOffline}, Session: expiredSess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -139,7 +140,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -165,18 +166,18 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, expect: func(t *testing.T) { assert.NotEqual(t, sess, areq.Session) - assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Round(time.Hour), areq.RequestedAt) + assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), areq.RequestedAt) assert.Equal(t, oauth2.Arguments{"foo", consts.ScopeOffline}, areq.GrantedScope) assert.Equal(t, oauth2.Arguments{"foo", consts.ScopeOffline}, areq.RequestedScope) assert.NotEqual(t, url.Values{"foo": []string{"bar"}}, areq.Form) - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) }, }, { @@ -200,7 +201,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", "baz", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -232,7 +233,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", "baz", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -264,7 +265,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "baz", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -294,13 +295,13 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, expect: func(t *testing.T) { assert.NotEqual(t, sess, areq.Session) - assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Round(time.Hour), areq.RequestedAt) + assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), areq.RequestedAt) assert.Equal(t, oauth2.Arguments{"foo", consts.ScopeOffline}, areq.GrantedScope) assert.Equal(t, oauth2.Arguments{"foo", consts.ScopeOffline}, areq.RequestedScope) assert.NotEqual(t, url.Values{"foo": []string{"bar"}}, areq.Form) @@ -328,7 +329,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar"}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -355,18 +356,18 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar"}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, expect: func(t *testing.T) { assert.NotEqual(t, sess, areq.Session) - assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Round(time.Hour), areq.RequestedAt) + assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), areq.RequestedAt) assert.Equal(t, oauth2.Arguments{"foo"}, areq.GrantedScope) assert.Equal(t, oauth2.Arguments{"foo"}, areq.RequestedScope) assert.NotEqual(t, url.Values{"foo": []string{"bar"}}, areq.Form) - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) }, }, { @@ -389,7 +390,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), } err = store.CreateRefreshTokenSession(context.TODO(), sig, req) require.NoError(t, err) diff --git a/handler/oauth2/flow_resource_owner.go b/handler/oauth2/flow_resource_owner.go index 89c59f3f..d140a948 100644 --- a/handler/oauth2/flow_resource_owner.go +++ b/handler/oauth2/flow_resource_owner.go @@ -11,6 +11,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -70,11 +71,11 @@ func (c *ResourceOwnerPasswordCredentialsGrantHandler) HandleTokenEndpointReques delete(request.GetRequestForm(), consts.FormParameterPassword) atLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypePassword, oauth2.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) - request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Truncate(jwt.TimePrecision)) rtLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypePassword, oauth2.RefreshToken, c.Config.GetRefreshTokenLifespan(ctx)) if rtLifespan > -1 { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Truncate(jwt.TimePrecision)) } return nil diff --git a/handler/oauth2/flow_resource_owner_test.go b/handler/oauth2/flow_resource_owner_test.go index 030e7613..1cabbb16 100644 --- a/handler/oauth2/flow_resource_owner_test.go +++ b/handler/oauth2/flow_resource_owner_test.go @@ -18,6 +18,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/testing/mock" + "authelia.com/provider/oauth2/token/jwt" ) func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) { @@ -89,8 +90,8 @@ func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) { store.EXPECT().Authenticate(context.TODO(), "peter", "pan").Return(nil) }, check: func(areq *oauth2.AccessRequest) { - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) }, }, } { diff --git a/handler/oauth2/strategy_jwt_profile_test.go b/handler/oauth2/strategy_jwt_profile_test.go index a29476c6..ab820c26 100644 --- a/handler/oauth2/strategy_jwt_profile_test.go +++ b/handler/oauth2/strategy_jwt_profile_test.go @@ -23,8 +23,8 @@ import ( var rsaKey = gen.MustRSAKey() -// returns a valid JWT type. The JWTClaims.ExpiresAt time is intentionally -// left empty to ensure it is pulled from the session's ExpiresAt map for +// returns a valid JWT type. The JWTClaims.ExpirationTime time is intentionally +// left empty to ensure it is pulled from the session's ExpirationTime map for // the given oauth2.TokenType. var jwtValidCase = func(tokenType oauth2.TokenType) *oauth2.Request { r := &oauth2.Request{ @@ -132,7 +132,7 @@ var jwtValidCaseWithRefreshExpiry = func(tokenType oauth2.TokenType) *oauth2.Req }, ExpiresAt: map[oauth2.TokenType]time.Time{ tokenType: time.Now().UTC().Add(time.Hour), - oauth2.RefreshToken: time.Now().UTC().Add(time.Hour * 2).Round(time.Hour), + oauth2.RefreshToken: time.Now().UTC().Add(time.Hour * 2).Truncate(time.Hour), }, }, } @@ -144,8 +144,8 @@ var jwtValidCaseWithRefreshExpiry = func(tokenType oauth2.TokenType) *oauth2.Req return r } -// returns an expired JWT type. The JWTClaims.ExpiresAt time is intentionally -// left empty to ensure it is pulled from the session's ExpiresAt map for +// returns an expired JWT type. The JWTClaims.ExpirationTime time is intentionally +// left empty to ensure it is pulled from the session's ExpirationTime map for // the given oauth2.TokenType. var jwtExpiredCase = func(tokenType oauth2.TokenType, now time.Time) *oauth2.Request { r := &oauth2.Request{ diff --git a/handler/openid/flow_device_authorization_test.go b/handler/openid/flow_device_authorization_test.go index 0f997f84..b23309b3 100644 --- a/handler/openid/flow_device_authorization_test.go +++ b/handler/openid/flow_device_authorization_test.go @@ -169,7 +169,7 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateTokenEndpointResponse(t *te setup: func() { sess := &DefaultSession{ Claims: &jwt.IDTokenClaims{ - RequestedAt: time.Now().UTC(), + RequestedAt: jwt.Now(), Subject: "foobar", }, Headers: &jwt.Headers{}, @@ -297,7 +297,7 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateTokenEndpointResponse(t *te setup: func() { sess := &DefaultSession{ Claims: &jwt.IDTokenClaims{ - RequestedAt: time.Now().UTC(), + RequestedAt: jwt.Now(), Subject: "", }, Headers: &jwt.Headers{}, diff --git a/handler/openid/flow_hybrid.go b/handler/openid/flow_hybrid.go index 892b5df0..4c9915bc 100644 --- a/handler/openid/flow_hybrid.go +++ b/handler/openid/flow_hybrid.go @@ -132,11 +132,11 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. // sets the proper access/refresh token lifetimes. // // if c.AuthorizeExplicitGrantHandler.RefreshTokenLifespan > -1 { - // requester.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.AuthorizeExplicitGrantHandler.RefreshTokenLifespan).Round(time.Second)) + // requester.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.AuthorizeExplicitGrantHandler.RefreshTokenLifespan).Truncate(jwt.TimePrecision)) // } // This is required because we must limit the authorize code lifespan. - requester.GetSession().SetExpiresAt(oauth2.AuthorizeCode, time.Now().UTC().Add(c.AuthorizeExplicitGrantHandler.Config.GetAuthorizeCodeLifespan(ctx)).Round(time.Second)) + requester.GetSession().SetExpiresAt(oauth2.AuthorizeCode, time.Now().UTC().Add(c.AuthorizeExplicitGrantHandler.Config.GetAuthorizeCodeLifespan(ctx)).Truncate(jwt.TimePrecision)) if err = c.AuthorizeExplicitGrantHandler.CoreStorage.CreateAuthorizeCodeSession(ctx, signature, requester.Sanitize(c.AuthorizeExplicitGrantHandler.GetSanitationWhiteList(ctx))); err != nil { return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebugError(err)) diff --git a/handler/openid/flow_hybrid_test.go b/handler/openid/flow_hybrid_test.go index 4b7a8d76..6413e4b2 100644 --- a/handler/openid/flow_hybrid_test.go +++ b/handler/openid/flow_hybrid_test.go @@ -339,7 +339,7 @@ func TestHybrid_HandleAuthorizeEndpointRequest(t *testing.T) { _, err := jwt.UnsafeParseSignedAny(idToken, claims) require.NoError(t, err) - internal.RequireEqualTime(t, time.Now().Add(*internal.TestLifespans.ImplicitGrantIDTokenLifespan), claims.ExpiresAt, time.Minute) + internal.RequireEqualTime(t, time.Now().Add(*internal.TestLifespans.ImplicitGrantIDTokenLifespan), claims.GetExpirationTimeSafe(), time.Minute) assert.NotEmpty(t, claims.CodeHash) assert.Empty(t, claims.StateHash) diff --git a/handler/openid/flow_refresh_token.go b/handler/openid/flow_refresh_token.go index 80563849..97699187 100644 --- a/handler/openid/flow_refresh_token.go +++ b/handler/openid/flow_refresh_token.go @@ -12,6 +12,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -49,7 +50,7 @@ func (c *OpenIDConnectRefreshHandler) HandleTokenEndpointRequest(ctx context.Con } // We need to reset the expires at value as this would be the previous expiry. - sess.IDTokenClaims().ExpiresAt = time.Time{} + sess.IDTokenClaims().ExpirationTime = jwt.NewNumericDate(time.Time{}) // These will be recomputed in PopulateTokenEndpointResponse sess.IDTokenClaims().JTI = "" @@ -92,7 +93,7 @@ func (c *OpenIDConnectRefreshHandler) PopulateTokenEndpointResponse(ctx context. claims.AccessTokenHash = c.GetAccessTokenHash(ctx, requester, responder) claims.JTI = uuid.New().String() claims.CodeHash = "" - claims.IssuedAt = time.Now().Truncate(time.Second) + claims.IssuedAt = jwt.Now() idTokenLifespan := oauth2.GetEffectiveLifespan(requester.GetClient(), oauth2.GrantTypeRefreshToken, oauth2.IDToken, c.Config.GetIDTokenLifespan(ctx)) return c.IssueExplicitIDToken(ctx, idTokenLifespan, requester, responder) diff --git a/handler/openid/strategy_jwt.go b/handler/openid/strategy_jwt.go index 646ff1b4..52802788 100644 --- a/handler/openid/strategy_jwt.go +++ b/handler/openid/strategy_jwt.go @@ -43,7 +43,7 @@ type DefaultSession struct { func NewDefaultSession() *DefaultSession { return &DefaultSession{ Claims: &jwt.IDTokenClaims{ - RequestedAt: time.Now().UTC(), + RequestedAt: jwt.Now(), }, Headers: &jwt.Headers{}, } @@ -121,7 +121,7 @@ type DefaultStrategy struct { // GenerateIDToken returns a JWT string. // -// lifespan is ignored if requester.GetSession().IDTokenClaims().ExpiresAt is not zero. +// lifespan is ignored if requester.GetSession().IDTokenClaims().ExpirationTime is not zero. // // TODO: Refactor time permitting. // @@ -151,38 +151,38 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura } // Adds a bit of wiggle room for timing issues - if claims.AuthTime.After(time.Now().UTC().Add(time.Second * 5)) { + if claims.GetAuthTimeSafe().After(time.Now().UTC().Add(time.Second * 5)) { return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to validate OpenID Connect request because authentication time is in the future.")) } if maxAge > 0 { switch { - case claims.AuthTime.IsZero(): + case claims.AuthTime == nil, claims.AuthTime.IsZero(): return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to generate id token because authentication time claim is required when max_age is set.")) - case claims.RequestedAt.IsZero(): + case claims.RequestedAt == nil, claims.RequestedAt.IsZero(): return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to generate id token because requested at claim is required when max_age is set.")) - case claims.AuthTime.Add(time.Second * time.Duration(maxAge)).Before(claims.RequestedAt): + case claims.AuthTime.Add(time.Second * time.Duration(maxAge)).Before(claims.RequestedAt.Time): return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to generate id token because authentication time does not satisfy max_age time.")) } } prompt := requester.GetRequestForm().Get(consts.FormParameterPrompt) if prompt != "" { - if claims.AuthTime.IsZero() { + if claims.AuthTime == nil || claims.AuthTime.IsZero() { return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Unable to determine validity of prompt parameter because auth_time is missing in id token claims.")) } } switch prompt { case consts.PromptTypeNone: - if !claims.AuthTime.Equal(claims.RequestedAt) && claims.AuthTime.After(claims.RequestedAt) { + if !claims.GetAuthTimeSafe().Equal(claims.GetRequestedAtSafe()) && claims.GetAuthTimeSafe().After(claims.GetRequestedAtSafe()) { return "", errorsx.WithStack(oauth2.ErrServerError. - WithDebugf("Failed to generate id token because prompt was set to 'none' but auth_time ('%s') happened after the authorization request ('%s') was registered, indicating that the user was logged in during this request which is not allowed.", claims.AuthTime, claims.RequestedAt)) + WithDebugf("Failed to generate id token because prompt was set to 'none' but auth_time ('%s') happened after the authorization request ('%s') was registered, indicating that the user was logged in during this request which is not allowed.", claims.GetAuthTimeSafe(), claims.GetRequestedAtSafe())) } case consts.PromptTypeLogin: - if !claims.AuthTime.Equal(claims.RequestedAt) && claims.AuthTime.Before(claims.RequestedAt) { + if !claims.GetAuthTimeSafe().Equal(claims.GetRequestedAtSafe()) && claims.GetAuthTimeSafe().Before(claims.GetRequestedAtSafe()) { return "", errorsx.WithStack(oauth2.ErrServerError. - WithDebugf("Failed to generate id token because prompt was set to 'login' but auth_time ('%s') happened before the authorization request ('%s') was registered, indicating that the user was not re-authenticated which is forbidden.", claims.AuthTime, claims.RequestedAt)) + WithDebugf("Failed to generate id token because prompt was set to 'login' but auth_time ('%s') happened before the authorization request ('%s') was registered, indicating that the user was not re-authenticated which is forbidden.", claims.GetAuthTimeSafe(), claims.GetRequestedAtSafe())) } } @@ -214,16 +214,16 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura } } - if claims.ExpiresAt.IsZero() { - claims.ExpiresAt = time.Now().UTC().Add(lifespan) + if claims.ExpirationTime == nil || claims.ExpirationTime.IsZero() { + claims.ExpirationTime = jwt.NewNumericDate(time.Now().Add(lifespan)) } - if claims.ExpiresAt.Before(time.Now().UTC()) { + if claims.ExpirationTime.Before(time.Now().UTC()) { return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to generate id token because expiry claim can not be in the past.")) } - if claims.AuthTime.IsZero() { - claims.AuthTime = time.Now().Truncate(time.Second).UTC() + if claims.AuthTime == nil || claims.AuthTime.IsZero() { + claims.AuthTime = jwt.Now() } if claims.Issuer == "" { @@ -240,7 +240,7 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura } claims.Audience = stringslice.Unique(append(claims.Audience, requester.GetClient().GetID())) - claims.IssuedAt = time.Now().UTC() + claims.IssuedAt = jwt.Now() token, _, err = h.Strategy.Encode(ctx, claims.ToMapClaims(), jwt.WithHeaders(sess.IDTokenHeaders()), jwt.WithClient(jwtClient)) diff --git a/handler/openid/strategy_jwt_test.go b/handler/openid/strategy_jwt_test.go index 4b1377c5..d0c479a3 100644 --- a/handler/openid/strategy_jwt_test.go +++ b/handler/openid/strategy_jwt_test.go @@ -52,8 +52,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().UTC(), + AuthTime: jwt.Now(), + RequestedAt: jwt.Now(), }, Headers: &jwt.Headers{}, }) @@ -66,8 +66,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { setup: func() { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ - Subject: "peter", - ExpiresAt: time.Now().UTC().Add(-time.Hour), + Subject: "peter", + ExpirationTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), }, Headers: &jwt.Headers{}, }) @@ -115,8 +115,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().UTC(), + AuthTime: jwt.Now(), + RequestedAt: jwt.Now(), }, Headers: &jwt.Headers{}, }) @@ -130,7 +130,7 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), }, Headers: &jwt.Headers{}, }) @@ -144,8 +144,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.Now(), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -159,8 +159,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.Now(), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -175,8 +175,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -190,8 +190,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.Now(), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -205,8 +205,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -220,8 +220,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -241,15 +241,15 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) token, _ := j.GenerateIDToken(context.TODO(), time.Duration(0), oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ - Subject: "peter", - ExpiresAt: time.Now().Add(-time.Hour).UTC(), + Subject: "peter", + ExpirationTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), }, Headers: &jwt.Headers{}, })) @@ -263,8 +263,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) diff --git a/handler/openid/validator.go b/handler/openid/validator.go index 85c3789d..0ad6441a 100644 --- a/handler/openid/validator.go +++ b/handler/openid/validator.go @@ -108,34 +108,34 @@ func (v *OpenIDConnectRequestValidator) ValidatePrompt(ctx context.Context, req } // Adds a bit of wiggle room for timing issues - if claims.AuthTime.After(time.Now().UTC().Add(time.Second * 5)) { + if claims.GetAuthTimeSafe().After(time.Now().UTC().Add(time.Second * 5)) { return errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to validate OpenID Connect request because authentication time is in the future.")) } if maxAge > 0 { switch { - case claims.AuthTime.IsZero(): + case claims.AuthTime == nil, claims.AuthTime.IsZero(): return errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to validate OpenID Connect request because authentication time claim is required when max_age is set.")) - case claims.RequestedAt.IsZero(): + case claims.RequestedAt == nil, claims.RequestedAt.IsZero(): return errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to validate OpenID Connect request because requested at claim is required when max_age is set.")) - case claims.AuthTime.Add(time.Second * time.Duration(maxAge)).Before(claims.RequestedAt): + case claims.GetAuthTimeSafe().Add(time.Second * time.Duration(maxAge)).Before(claims.GetRequestedAtSafe()): return errorsx.WithStack(oauth2.ErrLoginRequired.WithDebug("Failed to validate OpenID Connect request because authentication time does not satisfy max_age time.")) } } if stringslice.Has(requiredPrompt, consts.PromptTypeNone) { - if claims.AuthTime.IsZero() { + if claims.AuthTime == nil || claims.AuthTime.IsZero() { return errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to validate OpenID Connect request because because auth_time is missing from session.")) } - if !claims.AuthTime.Equal(claims.RequestedAt) && claims.AuthTime.After(claims.RequestedAt) { + if !claims.GetAuthTimeSafe().Equal(claims.GetRequestedAtSafe()) && claims.GetAuthTimeSafe().After(claims.GetRequestedAtSafe()) { // !claims.AuthTime.Truncate(time.Second).Equal(claims.RequestedAt) && claims.AuthTime.Truncate(time.Second).Before(claims.RequestedAt) { - return errorsx.WithStack(oauth2.ErrLoginRequired.WithHintf("Failed to validate OpenID Connect request because prompt was set to 'none' but auth_time ('%s') happened after the authorization request ('%s') was registered, indicating that the user was logged in during this request which is not allowed.", claims.AuthTime, claims.RequestedAt)) + return errorsx.WithStack(oauth2.ErrLoginRequired.WithHintf("Failed to validate OpenID Connect request because prompt was set to 'none' but auth_time ('%s') happened after the authorization request ('%s') was registered, indicating that the user was logged in during this request which is not allowed.", claims.GetAuthTimeSafe(), claims.GetRequestedAtSafe())) } } if stringslice.Has(requiredPrompt, consts.PromptTypeLogin) { - if claims.AuthTime.Before(claims.RequestedAt) { - return errorsx.WithStack(oauth2.ErrLoginRequired.WithHintf("Failed to validate OpenID Connect request because prompt was set to 'login' but auth_time ('%s') happened before the authorization request ('%s') was registered, indicating that the user was not re-authenticated which is forbidden.", claims.AuthTime, claims.RequestedAt)) + if claims.GetAuthTimeSafe().Before(claims.GetRequestedAtSafe()) { + return errorsx.WithStack(oauth2.ErrLoginRequired.WithHintf("Failed to validate OpenID Connect request because prompt was set to 'login' but auth_time ('%s') happened before the authorization request ('%s') was registered, indicating that the user was not re-authenticated which is forbidden.", claims.GetAuthTimeSafe(), claims.GetRequestedAtSafe())) } } diff --git a/handler/openid/validator_test.go b/handler/openid/validator_test.go index c493cf40..a1864c89 100644 --- a/handler/openid/validator_test.go +++ b/handler/openid/validator_test.go @@ -59,8 +59,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Minute), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, }, }, @@ -74,8 +74,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Minute), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, }, }, @@ -89,8 +89,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Minute), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, }, }, @@ -103,7 +103,7 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), + RequestedAt: jwt.Now(), }, }, }, @@ -116,8 +116,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC().Add(-time.Minute), - AuthTime: time.Now().UTC(), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + AuthTime: jwt.Now(), }, }, }, @@ -130,8 +130,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Minute), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, }, }, @@ -144,8 +144,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.Now(), }, }, }, @@ -158,8 +158,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.Now(), }, }, }, @@ -172,8 +172,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC().Add(-time.Second * 5), - AuthTime: time.Now().UTC().Add(-time.Second), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Second * 5)), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), }, }, }, @@ -186,8 +186,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC().Add(-time.Second * 5), - AuthTime: time.Now().UTC().Add(-time.Second), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Second * 5)), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), }, }, }, @@ -200,14 +200,14 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Second), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), }, }, idTokenHint: genIDToken(jwt.IDTokenClaims{ - Subject: "bar", - RequestedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour), + Subject: "bar", + RequestedAt: jwt.Now(), + ExpirationTime: jwt.NewNumericDate(time.Now().Add(time.Hour)), }), }, { @@ -219,14 +219,14 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Second), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), }, }, idTokenHint: genIDToken(jwt.IDTokenClaims{ - Subject: "foo", - RequestedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour), + Subject: "foo", + RequestedAt: jwt.Now(), + ExpirationTime: jwt.NewNumericDate(time.Now().Add(time.Hour)), }), }, { @@ -237,16 +237,16 @@ func TestValidatePrompt(t *testing.T) { s: &DefaultSession{ Subject: "foo", Claims: &jwt.IDTokenClaims{ - Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Second), - ExpiresAt: time.Now().UTC().Add(-time.Second), + Subject: "foo", + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), + ExpirationTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), }, }, idTokenHint: genIDToken(jwt.IDTokenClaims{ - Subject: "foo", - RequestedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour), + Subject: "foo", + RequestedAt: jwt.Now(), + ExpirationTime: jwt.NewNumericDate(time.Now().Add(time.Hour)), }), }, } { diff --git a/handler/rfc8628/device_authorize_handler.go b/handler/rfc8628/device_authorize_handler.go index 8c879657..dad35955 100644 --- a/handler/rfc8628/device_authorize_handler.go +++ b/handler/rfc8628/device_authorize_handler.go @@ -7,6 +7,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -38,7 +39,7 @@ func (d *DeviceAuthorizeHandler) HandleRFC8628DeviceAuthorizeEndpointRequest(ctx dar.SetDeviceCodeSignature(deviceCodeSignature) dar.SetUserCodeSignature(userCodeSignature) - expireAt := time.Now().UTC().Add(d.Config.GetRFC8628CodeLifespan(ctx)).Round(time.Second) + expireAt := time.Now().UTC().Add(d.Config.GetRFC8628CodeLifespan(ctx)).Truncate(jwt.TimePrecision) session.SetExpiresAt(oauth2.DeviceCode, expireAt) session.SetExpiresAt(oauth2.UserCode, expireAt) diff --git a/handler/rfc8628/token_endpoint_handler_test.go b/handler/rfc8628/token_endpoint_handler_test.go index 90e1703d..ecdc0e84 100644 --- a/handler/rfc8628/token_endpoint_handler_test.go +++ b/handler/rfc8628/token_endpoint_handler_test.go @@ -19,6 +19,7 @@ import ( "authelia.com/provider/oauth2/storage" "authelia.com/provider/oauth2/testing/mock" "authelia.com/provider/oauth2/token/hmac" + "authelia.com/provider/oauth2/token/jwt" ) var o2hmacshaStrategy = hoauth2.HMACCoreStrategy{ @@ -428,16 +429,16 @@ func TestDeviceAuthorizeCode_HandleTokenEndpointRequest(t *testing.T) { }, }, check: func(t *testing.T, areq *oauth2.AccessRequest, authreq *oauth2.DeviceAuthorizeRequest) { - assert.Equal(t, time.Now().Add(time.Minute).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) - assert.Equal(t, time.Now().Add(time.Minute).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) + assert.Equal(t, time.Now().Add(time.Minute).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) + assert.Equal(t, time.Now().Add(time.Minute).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) }, setup: func(t *testing.T, areq *oauth2.AccessRequest, authreq *oauth2.DeviceAuthorizeRequest) { authreq = oauth2.NewDeviceAuthorizeRequest() authreq.SetSession(openid.NewDefaultSession()) authreq.GetSession().SetExpiresAt(oauth2.UserCode, - time.Now().Add(-time.Hour).UTC().Round(time.Second)) + time.Now().Add(-time.Hour).UTC().Truncate(jwt.TimePrecision)) authreq.GetSession().SetExpiresAt(oauth2.DeviceCode, - time.Now().Add(-time.Hour).UTC().Round(time.Second)) + time.Now().Add(-time.Hour).UTC().Truncate(jwt.TimePrecision)) dCode, dSig, err := strategy.GenerateRFC8628DeviceCode(context.TODO()) require.NoError(t, err) _, uSig, err := strategy.GenerateRFC8628UserCode(context.TODO()) diff --git a/handler/rfc8628/user_authorize_handler_test.go b/handler/rfc8628/user_authorize_handler_test.go index c02703d8..7240a072 100644 --- a/handler/rfc8628/user_authorize_handler_test.go +++ b/handler/rfc8628/user_authorize_handler_test.go @@ -16,6 +16,7 @@ import ( "authelia.com/provider/oauth2/handler/openid" . "authelia.com/provider/oauth2/handler/rfc8628" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" ) func TestUserAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpointResponse(t *testing.T) { @@ -37,7 +38,7 @@ func TestUserAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpointResponse(t *te dar.SetSession(openid.NewDefaultSession()) dar.GetSession().SetExpiresAt(oauth2.UserCode, time.Now().UTC().Add( - f.Config.GetRFC8628CodeLifespan(a.ctx)).Round(time.Second)) + f.Config.GetRFC8628CodeLifespan(a.ctx)).Truncate(jwt.TimePrecision)) code, sig, err := f.Strategy.GenerateRFC8628UserCode(a.ctx) require.NoError(t, err) dar.SetUserCodeSignature(sig) @@ -243,7 +244,7 @@ func TestUserAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpointResponse_Handl dar.SetSession(openid.NewDefaultSession()) dar.GetSession().SetExpiresAt(oauth2.UserCode, time.Now().UTC().Add( - f.Config.GetRFC8628CodeLifespan(a.ctx)).Round(time.Second)) + f.Config.GetRFC8628CodeLifespan(a.ctx)).Truncate(jwt.TimePrecision)) code, sig, err := f.Strategy.GenerateRFC8628UserCode(a.ctx) require.NoError(t, err) dar.SetUserCodeSignature(sig) @@ -400,7 +401,7 @@ func TestUserAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpointResponse_Handl dar.SetSession(openid.NewDefaultSession()) dar.GetSession().SetExpiresAt(oauth2.UserCode, time.Now().UTC().Add( - f.Config.GetRFC8628CodeLifespan(a.ctx)).Round(time.Second)) + f.Config.GetRFC8628CodeLifespan(a.ctx)).Truncate(jwt.TimePrecision)) code, sig, err := f.Strategy.GenerateRFC8628UserCode(a.ctx) require.NoError(t, err) dar.SetUserCodeSignature(sig) @@ -443,7 +444,7 @@ func TestUserAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpointResponse_Handl dar.SetSession(openid.NewDefaultSession()) dar.GetSession().SetExpiresAt(oauth2.UserCode, time.Now().UTC().Add(time.Duration(-1)* - f.Config.GetRFC8628CodeLifespan(a.ctx)).Round(time.Second)) + f.Config.GetRFC8628CodeLifespan(a.ctx)).Truncate(jwt.TimePrecision)) code, sig, err := f.Strategy.GenerateRFC8628UserCode(a.ctx) require.NoError(t, err) dar.SetUserCodeSignature(sig) diff --git a/handler/rfc8693/access_token_type_handler.go b/handler/rfc8693/access_token_type_handler.go index 7e83ae76..b03551c9 100644 --- a/handler/rfc8693/access_token_type_handler.go +++ b/handler/rfc8693/access_token_type_handler.go @@ -10,6 +10,7 @@ import ( hoauth2 "authelia.com/provider/oauth2/handler/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -161,7 +162,7 @@ func (c *AccessTokenTypeHandler) issue(ctx context.Context, request oauth2.Acces issueRefreshToken := c.canIssueRefreshToken(request) if issueRefreshToken { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.RefreshTokenLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.RefreshTokenLifespan).Truncate(jwt.TimePrecision)) refresh, refreshSignature, err := c.CoreStrategy.GenerateRefreshToken(ctx, request) if err != nil { return errors.WithStack(oauth2.ErrServerError.WithDebugError(err)) diff --git a/handler/rfc8693/custom_jwt_type_handler.go b/handler/rfc8693/custom_jwt_type_handler.go index 18b4f392..2fb274c2 100644 --- a/handler/rfc8693/custom_jwt_type_handler.go +++ b/handler/rfc8693/custom_jwt_type_handler.go @@ -176,8 +176,8 @@ func (c *CustomJWTTypeHandler) issue(ctx context.Context, request oauth2.AccessR claims.Subject = request.GetClient().GetID() } - if claims.ExpiresAt.IsZero() { - claims.ExpiresAt = time.Now().UTC().Add(jwtType.Expiry) + if claims.ExpirationTime == nil || claims.ExpirationTime.IsZero() { + claims.ExpirationTime = jwt.NewNumericDate(time.Now().Add(jwtType.Expiry)) } if claims.Issuer == "" { @@ -201,7 +201,7 @@ func (c *CustomJWTTypeHandler) issue(ctx context.Context, request oauth2.AccessR claims.JTI = uuid.New().String() } - claims.IssuedAt = time.Now().UTC() + claims.IssuedAt = jwt.Now() token, _, err := c.Strategy.Encode(ctx, claims.ToMapClaims(), jwt.WithHeaders(sess.IDTokenHeaders()), jwt.WithIDTokenClient(request.GetClient())) if err != nil { @@ -210,7 +210,7 @@ func (c *CustomJWTTypeHandler) issue(ctx context.Context, request oauth2.AccessR response.SetAccessToken(token) response.SetTokenType("N_A") - response.SetExpiresIn(time.Duration(claims.ExpiresAt.UnixNano() - time.Now().UTC().UnixNano())) + response.SetExpiresIn(time.Duration(claims.GetExpirationTimeSafe().UnixNano() - time.Now().UTC().UnixNano())) return nil } diff --git a/handler/rfc8693/refresh_token_type_handler.go b/handler/rfc8693/refresh_token_type_handler.go index c7e673b2..3e31f251 100644 --- a/handler/rfc8693/refresh_token_type_handler.go +++ b/handler/rfc8693/refresh_token_type_handler.go @@ -10,6 +10,7 @@ import ( hoauth2 "authelia.com/provider/oauth2/handler/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -148,7 +149,7 @@ func (c *RefreshTokenTypeHandler) validate(ctx context.Context, request oauth2.A } func (c *RefreshTokenTypeHandler) issue(ctx context.Context, request oauth2.AccessRequester, response oauth2.AccessResponder) error { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.RefreshTokenLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.RefreshTokenLifespan).Truncate(jwt.TimePrecision)) refresh, refreshSignature, err := c.CoreStrategy.GenerateRefreshToken(ctx, request) if err != nil { return errors.WithStack(oauth2.ErrServerError.WithDebugError(err)) diff --git a/integration/oidc_explicit_test.go b/integration/oidc_explicit_test.go index 77c2ec65..e71a535c 100644 --- a/integration/oidc_explicit_test.go +++ b/integration/oidc_explicit_test.go @@ -105,8 +105,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(time.Second).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(time.Second)), }), description: "should not pass missing redirect uri", setup: func(oauthClient *xoauth2.Config) string { @@ -130,8 +130,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(time.Second).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(time.Second)), }), description: "should pass", setup: func(oauthClient *xoauth2.Config) string { @@ -143,8 +143,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(time.Second).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(time.Second)), }), description: "should not pass missing redirect uri", setup: func(oauthClient *xoauth2.Config) string { @@ -158,8 +158,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(time.Second).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(time.Second)), }), description: "should not pass missing redirect uri", setup: func(oauthClient *xoauth2.Config) string { @@ -173,8 +173,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(-time.Minute).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }), description: "should fail because authentication was in the past", setup: func(oauthClient *xoauth2.Config) string { @@ -187,8 +187,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(-time.Minute).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }), description: "should pass because authorization was in the past and no login was required", setup: func(oauthClient *xoauth2.Config) string { diff --git a/internal/test_helpers.go b/internal/test_helpers.go index b2ce1f64..d09e32c2 100644 --- a/internal/test_helpers.go +++ b/internal/test_helpers.go @@ -67,11 +67,11 @@ func ExtractJwtExpClaim(t *testing.T, token string) *time.Time { require.NoError(t, err) - if claims.ExpiresAt.IsZero() { + if claims.ExpirationTime == nil { return nil } - return &claims.ExpiresAt + return &claims.ExpirationTime.Time } //nolint:gocyclo diff --git a/introspection_response_writer_test.go b/introspection_response_writer_test.go index faf033a7..be35c122 100644 --- a/introspection_response_writer_test.go +++ b/introspection_response_writer_test.go @@ -99,7 +99,7 @@ func TestWriteIntrospectionResponseBody(t *testing.T) { hasExtra: false, }, { - description: "should success for ExpiresAt not set access token", + description: "should success for ExpirationTime not set access token", setup: func() { ires.Active = true ires.TokenUse = AccessToken diff --git a/rfc8628_device_authorize_write_test.go b/rfc8628_device_authorize_write_test.go index 4b7eb5b8..8e307ce6 100644 --- a/rfc8628_device_authorize_write_test.go +++ b/rfc8628_device_authorize_write_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" . "authelia.com/provider/oauth2" + "authelia.com/provider/oauth2/token/jwt" ) func TestWriteDeviceAuthorizeResponse(t *testing.T) { @@ -25,10 +26,10 @@ func TestWriteDeviceAuthorizeResponse(t *testing.T) { resp.SetUserCode("AAAA") resp.SetDeviceCode("BBBB") resp.SetInterval(int( - oauth2.Config.GetRFC8628TokenPollingInterval(context.TODO()).Round(time.Second).Seconds(), + oauth2.Config.GetRFC8628TokenPollingInterval(context.TODO()).Truncate(jwt.TimePrecision).Seconds(), )) resp.SetExpiresIn(int64( - time.Now().Round(time.Second).Add(oauth2.Config.GetRFC8628CodeLifespan(context.TODO())).Second(), + time.Now().Truncate(jwt.TimePrecision).Add(oauth2.Config.GetRFC8628CodeLifespan(context.TODO())).Second(), )) resp.SetVerificationURI(oauth2.Config.GetRFC8628UserVerificationURL(context.TODO())) resp.SetVerificationURIComplete( diff --git a/testing/mock/client.go b/testing/mock/client.go index a665c669..7c88ba6c 100644 --- a/testing/mock/client.go +++ b/testing/mock/client.go @@ -42,7 +42,7 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { // GetAudience mocks base method. func (m *MockClient) GetAudience() oauth2.Arguments { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAudienceX") + ret := m.ctrl.Call(m, "GetAudience") ret0, _ := ret[0].(oauth2.Arguments) return ret0 } @@ -50,7 +50,7 @@ func (m *MockClient) GetAudience() oauth2.Arguments { // GetAudience indicates an expected call of GetAudience. func (mr *MockClientMockRecorder) GetAudience() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAudienceX", reflect.TypeOf((*MockClient)(nil).GetAudience)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAudience", reflect.TypeOf((*MockClient)(nil).GetAudience)) } // GetClientSecret mocks base method. diff --git a/token/jarm/generate.go b/token/jarm/generate.go index 23321b65..4801fc5e 100644 --- a/token/jarm/generate.go +++ b/token/jarm/generate.go @@ -4,9 +4,6 @@ import ( "context" "errors" "net/url" - "time" - - "github.com/google/uuid" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/token/jwt" @@ -22,15 +19,15 @@ func EncodeParameters(token, _ string, tErr error) (parameters url.Values, err e } // Generate generates the token and signature for a JARM response. -func Generate(ctx context.Context, config Configurator, client Client, session any, in url.Values) (token, signature string, err error) { +func Generate(ctx context.Context, config Configurator, client Client, session any, parameters url.Values) (token, signature string, err error) { headers := map[string]any{} if alg := client.GetAuthorizationSignedResponseAlg(); len(alg) > 0 { - headers[consts.JSONWebTokenHeaderAlgorithm] = alg + headers[jwt.JSONWebTokenHeaderAlgorithm] = alg } if kid := client.GetAuthorizationSignedResponseKeyID(); len(kid) > 0 { - headers[consts.JSONWebTokenHeaderKeyIdentifier] = kid + headers[jwt.JSONWebTokenHeaderKeyIdentifier] = kid } var issuer string @@ -55,29 +52,22 @@ func Generate(ctx context.Context, config Configurator, client Client, session a return "", "", errors.New("The JARM response modes require the Authorize Requester session to implement either the openid.Session or oauth2.JWTSessionContainer interfaces but it doesn't.") } - if value, ok = src[consts.ClaimIssuer]; ok { + if value, ok = src[jwt.ClaimIssuer]; ok { issuer, _ = value.(string) } } - claims := &jwt.JARMClaims{ - JTI: uuid.New().String(), - Issuer: issuer, - IssuedAt: time.Now().UTC(), - ExpiresAt: time.Now().UTC().Add(config.GetJWTSecuredAuthorizeResponseModeLifespan(ctx)), - Audience: []string{client.GetID()}, - Extra: map[string]any{}, - } + claims := jwt.NewJARMClaims(issuer, jwt.ClaimStrings{client.GetID()}, config.GetJWTSecuredAuthorizeResponseModeLifespan(ctx)) - for param := range in { - claims.Extra[param] = in.Get(param) + for param := range parameters { + claims.Extra[param] = parameters.Get(param) } - var signer jwt.Strategy + var strategy jwt.Strategy - if signer = config.GetJWTSecuredAuthorizeResponseModeStrategy(ctx); signer == nil { + if strategy = config.GetJWTSecuredAuthorizeResponseModeStrategy(ctx); strategy == nil { return "", "", errors.New("The JARM response modes require the JWTSecuredAuthorizeResponseModeSignerProvider to return a jwt.Strategy but it didn't.") } - return signer.Encode(ctx, claims.ToMapClaims(), jwt.WithHeaders(&jwt.Headers{Extra: headers}), jwt.WithJARMClient(client)) + return strategy.Encode(ctx, claims.ToMapClaims(), jwt.WithHeaders(&jwt.Headers{Extra: headers}), jwt.WithJARMClient(client)) } diff --git a/token/jwt/claims_id_token.go b/token/jwt/claims_id_token.go index aa45f8fd..4dec51a2 100644 --- a/token/jwt/claims_id_token.go +++ b/token/jwt/claims_id_token.go @@ -5,6 +5,7 @@ package jwt import ( "bytes" + "errors" "fmt" "time" @@ -22,10 +23,10 @@ type IDTokenClaims struct { Subject string `json:"sub"` Audience []string `json:"aud"` Nonce string `json:"nonce"` - ExpiresAt time.Time `json:"exp"` - IssuedAt time.Time `json:"iat"` - RequestedAt time.Time `json:"rat"` - AuthTime time.Time `json:"auth_time"` + ExpirationTime *NumericDate `json:"exp"` + IssuedAt *NumericDate `json:"iat"` + RequestedAt *NumericDate `json:"rat"` + AuthTime *NumericDate `json:"auth_time"` AccessTokenHash string `json:"at_hash"` AuthenticationContextClassReference string `json:"acr"` AuthenticationMethodsReferences []string `json:"amr"` @@ -34,6 +35,141 @@ type IDTokenClaims struct { Extra map[string]any `json:"ext"` } +func (c *IDTokenClaims) GetExpirationTime() (exp *NumericDate, err error) { + return c.ExpirationTime, nil +} + +func (c *IDTokenClaims) GetIssuedAt() (iat *NumericDate, err error) { + return c.IssuedAt, nil +} + +func (c *IDTokenClaims) GetNotBefore() (nbf *NumericDate, err error) { + return toNumericDate(ClaimNotBefore) +} + +func (c *IDTokenClaims) GetIssuer() (iss string, err error) { + return c.Issuer, nil +} + +func (c *IDTokenClaims) GetSubject() (sub string, err error) { + return c.Subject, nil +} + +func (c *IDTokenClaims) GetAudience() (aud ClaimStrings, err error) { + return c.Audience, nil +} + +func (c IDTokenClaims) Valid(opts ...ClaimValidationOption) (err error) { + vopts := &ClaimValidationOptions{} + + for _, opt := range opts { + opt(vopts) + } + + var now int64 + + if vopts.timef != nil { + now = vopts.timef().UTC().Unix() + } else { + now = TimeFunc().UTC().Unix() + } + + vErr := new(ValidationError) + + var date *NumericDate + + if date, err = c.GetExpirationTime(); !validDate(validInt64Future, now, vopts.expRequired, date, err) { + vErr.Inner = errors.New("Token is expired") + vErr.Errors |= ValidationErrorExpired + } + + if date, err = c.GetIssuedAt(); !validDate(validInt64Past, now, vopts.expRequired, date, err) { + vErr.Inner = errors.New("Token used before issued") + vErr.Errors |= ValidationErrorIssuedAt + } + + if date, err = c.GetNotBefore(); !validDate(validInt64Past, now, vopts.expRequired, date, err) { + vErr.Inner = errors.New("Token is not valid yet") + vErr.Errors |= ValidationErrorNotValidYet + } + + var str string + + if len(vopts.iss) != 0 { + if str, err = c.GetIssuer(); err != nil { + vErr.Inner = errors.New("Token has invalid issuer") + vErr.Errors |= ValidationErrorIssuer + } else if !validString(str, vopts.iss, true) { + vErr.Inner = errors.New("Token has invalid issuer") + vErr.Errors |= ValidationErrorIssuer + } + } + + if len(vopts.sub) != 0 { + if str, err = c.GetSubject(); err != nil { + vErr.Inner = errors.New("Token has invalid subject") + vErr.Errors |= ValidationErrorIssuer + } else if !validString(str, vopts.sub, true) { + vErr.Inner = errors.New("Token has invalid subject") + vErr.Errors |= ValidationErrorSubject + } + } + + var aud ClaimStrings + + if len(vopts.aud) != 0 { + if aud, err = c.GetAudience(); err != nil || aud == nil || !aud.ValidAny(vopts.aud, true) { + vErr.Inner = errors.New("Token has invalid audience") + vErr.Errors |= ValidationErrorAudience + } + } + + if len(vopts.audAll) != 0 { + if aud, err = c.GetAudience(); err != nil || aud == nil || !aud.ValidAll(vopts.audAll, true) { + vErr.Inner = errors.New("Token has invalid audience") + vErr.Errors |= ValidationErrorAudience + } + } + + if vErr.valid() { + return nil + } + + return vErr +} + +func (c *IDTokenClaims) GetExpirationTimeSafe() time.Time { + if c.ExpirationTime == nil { + return time.Unix(0, 0).UTC() + } + + return c.ExpirationTime.UTC() +} + +func (c *IDTokenClaims) GetIssuedAtSafe() time.Time { + if c.IssuedAt == nil { + return time.Unix(0, 0).UTC() + } + + return c.IssuedAt.UTC() +} + +func (c *IDTokenClaims) GetAuthTimeSafe() time.Time { + if c.AuthTime == nil { + return time.Unix(0, 0).UTC() + } + + return c.AuthTime.UTC() +} + +func (c *IDTokenClaims) GetRequestedAtSafe() time.Time { + if c.RequestedAt == nil { + return time.Unix(0, 0).UTC() + } + + return c.RequestedAt.UTC() +} + func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { claims := MapClaims{} @@ -44,7 +180,10 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { return errorsx.WithStack(err) } - var ok bool + var ( + ok bool + err error + ) for claim, value := range claims { ok = false @@ -58,6 +197,8 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { c.Subject, ok = value.(string) case ClaimAudience: switch aud := value.(type) { + case nil: + ok = true case string: ok = true @@ -84,13 +225,21 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { case ClaimNonce: c.Nonce, ok = value.(string) case ClaimExpirationTime: - c.ExpiresAt, ok = toTime(value, c.ExpiresAt) + if c.ExpirationTime, err = toNumericDate(value); err == nil { + ok = true + } case ClaimIssuedAt: - c.IssuedAt, ok = toTime(value, c.IssuedAt) + if c.IssuedAt, err = toNumericDate(value); err == nil { + ok = true + } case ClaimRequestedAt: - c.RequestedAt, ok = toTime(value, c.RequestedAt) + if c.RequestedAt, err = toNumericDate(value); err == nil { + ok = true + } case ClaimAuthenticationTime: - c.AuthTime, ok = toTime(value, c.AuthTime) + if c.AuthTime, err = toNumericDate(value); err == nil { + ok = true + } case ClaimCodeHash: c.CodeHash, ok = value.(string) case ClaimStateHash: @@ -143,14 +292,14 @@ func (c *IDTokenClaims) ToMap() map[string]any { ret[consts.ClaimAudience] = []string{} } - if !c.IssuedAt.IsZero() { + if c.IssuedAt != nil { ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() } else { delete(ret, consts.ClaimIssuedAt) } - if !c.ExpiresAt.IsZero() { - ret[consts.ClaimExpirationTime] = c.ExpiresAt.Unix() + if c.ExpirationTime != nil { + ret[consts.ClaimExpirationTime] = c.ExpirationTime.Unix() } else { delete(ret, consts.ClaimExpirationTime) } @@ -179,7 +328,7 @@ func (c *IDTokenClaims) ToMap() map[string]any { delete(ret, consts.ClaimStateHash) } - if !c.AuthTime.IsZero() { + if c.AuthTime != nil { ret[consts.ClaimAuthenticationTime] = c.AuthTime.Unix() } else { delete(ret, consts.ClaimAuthenticationTime) @@ -200,11 +349,17 @@ func (c *IDTokenClaims) ToMap() map[string]any { return ret } +// ToMapClaims will return a jwt-go MapClaims representation +func (c IDTokenClaims) ToMapClaims() MapClaims { + return c.ToMap() +} + // Add will add a key-value pair to the extra field func (c *IDTokenClaims) Add(key string, value any) { if c.Extra == nil { c.Extra = make(map[string]any) } + c.Extra[key] = value } @@ -213,7 +368,19 @@ func (c *IDTokenClaims) Get(key string) any { return c.ToMap()[key] } -// ToMapClaims will return a jwt-go MapClaims representation -func (c IDTokenClaims) ToMapClaims() MapClaims { - return c.ToMap() +func (c IDTokenClaims) toNumericDate(key string) (date *NumericDate, err error) { + var ( + v any + ok bool + ) + + if v, ok = c.Extra[key]; !ok { + return nil, nil + } + + return toNumericDate(v) } + +var ( + _ Claims = (*IDTokenClaims)(nil) +) diff --git a/token/jwt/claims_id_token_test.go b/token/jwt/claims_id_token_test.go index 66801dc4..bc35493f 100644 --- a/token/jwt/claims_id_token_test.go +++ b/token/jwt/claims_id_token_test.go @@ -14,9 +14,9 @@ import ( ) func TestIDTokenAssert(t *testing.T) { - assert.NoError(t, (&IDTokenClaims{ExpiresAt: time.Now().UTC().Add(time.Hour)}). + assert.NoError(t, (&IDTokenClaims{ExpirationTime: NewNumericDate(time.Now().Add(time.Hour))}). ToMapClaims().Valid()) - assert.Error(t, (&IDTokenClaims{ExpiresAt: time.Now().UTC().Add(-time.Hour)}). + assert.Error(t, (&IDTokenClaims{ExpirationTime: NewNumericDate(time.Now().Add(-time.Hour))}). ToMapClaims().Valid()) assert.NotEmpty(t, (new(IDTokenClaims)).ToMapClaims()[ClaimJWTID]) @@ -26,12 +26,12 @@ func TestIDTokenClaimsToMap(t *testing.T) { idTokenClaims := &IDTokenClaims{ JTI: "foo-id", Subject: "peter", - IssuedAt: time.Now().UTC().Round(time.Second), + IssuedAt: Now(), Issuer: "authelia", Audience: []string{"tests"}, - ExpiresAt: time.Now().UTC().Add(time.Hour).Round(time.Second), - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().UTC(), + ExpirationTime: NewNumericDate(time.Now().Add(time.Hour)), + AuthTime: Now(), + RequestedAt: Now(), AccessTokenHash: "foobar", CodeHash: "barfoo", StateHash: "boofar", @@ -48,7 +48,7 @@ func TestIDTokenClaimsToMap(t *testing.T) { ClaimIssuedAt: idTokenClaims.IssuedAt.Unix(), ClaimIssuer: idTokenClaims.Issuer, ClaimAudience: idTokenClaims.Audience, - ClaimExpirationTime: idTokenClaims.ExpiresAt.Unix(), + ClaimExpirationTime: idTokenClaims.ExpirationTime.Unix(), "foo": idTokenClaims.Extra["foo"], "baz": idTokenClaims.Extra["baz"], ClaimAccessTokenHash: idTokenClaims.AccessTokenHash, @@ -66,7 +66,7 @@ func TestIDTokenClaimsToMap(t *testing.T) { consts.ClaimIssuedAt: idTokenClaims.IssuedAt.Unix(), consts.ClaimIssuer: idTokenClaims.Issuer, consts.ClaimAudience: idTokenClaims.Audience, - consts.ClaimExpirationTime: idTokenClaims.ExpiresAt.Unix(), + consts.ClaimExpirationTime: idTokenClaims.ExpirationTime.Unix(), "foo": idTokenClaims.Extra["foo"], "baz": idTokenClaims.Extra["baz"], consts.ClaimAccessTokenHash: idTokenClaims.AccessTokenHash, diff --git a/token/jwt/claims_jarm.go b/token/jwt/claims_jarm.go index 5328cf5e..dd0a4773 100644 --- a/token/jwt/claims_jarm.go +++ b/token/jwt/claims_jarm.go @@ -1,21 +1,61 @@ package jwt import ( + "fmt" "time" "github.com/google/uuid" - - "authelia.com/provider/oauth2/internal/consts" ) +func NewJARMClaims(issuer string, aud ClaimStrings, lifespan time.Duration) *JARMClaims { + now := time.Now() + + return &JARMClaims{ + Issuer: issuer, + Audience: aud, + JTI: uuid.NewString(), + IssuedAt: NewNumericDate(now), + ExpirationTime: NewNumericDate(now.Add(lifespan)), + Extra: map[string]any{}, + } +} + // JARMClaims represent a token's claims. type JARMClaims struct { - Issuer string - Audience []string - JTI string - IssuedAt time.Time - ExpiresAt time.Time - Extra map[string]any + Issuer string `json:"iss"` + Audience ClaimStrings `json:"aud"` + JTI string `json:"jti"` + IssuedAt *NumericDate `json:"iat,omitempty"` + ExpirationTime *NumericDate `json:"exp,omitempty"` + Extra map[string]any `json:"-"` +} + +func (c *JARMClaims) GetExpirationTime() (exp *NumericDate, err error) { + return c.ExpirationTime, nil +} + +func (c *JARMClaims) GetIssuedAt() (iat *NumericDate, err error) { + return c.IssuedAt, nil +} + +func (c *JARMClaims) GetNotBefore() (nbf *NumericDate, err error) { + return c.toNumericDate(ClaimNotBefore) +} + +func (c *JARMClaims) GetIssuer() (iss string, err error) { + return c.Issuer, nil +} + +func (c *JARMClaims) GetSubject() (sub string, err error) { + return c.toString(ClaimIssuer) +} + +func (c *JARMClaims) GetAudience() (aud ClaimStrings, err error) { + return c.Audience, nil +} + +func (c *JARMClaims) Valid(opts ...ClaimValidationOption) (err error) { + return nil } // ToMap will transform the headers to a map structure @@ -35,19 +75,19 @@ func (c *JARMClaims) ToMap() map[string]any { } if len(c.Audience) > 0 { - ret[ClaimAudience] = c.Audience + ret[ClaimAudience] = []string(c.Audience) } else { ret[ClaimAudience] = []string{} } - if !c.IssuedAt.IsZero() { + if c.IssuedAt != nil { ret[ClaimIssuedAt] = c.IssuedAt.Unix() } else { delete(ret, ClaimIssuedAt) } - if !c.ExpiresAt.IsZero() { - ret[ClaimExpirationTime] = c.ExpiresAt.Unix() + if c.ExpirationTime != nil { + ret[ClaimExpirationTime] = c.ExpirationTime.Unix() } else { delete(ret, ClaimExpirationTime) } @@ -55,6 +95,11 @@ func (c *JARMClaims) ToMap() map[string]any { return ret } +// ToMapClaims will return a jwt-go MapClaims representation +func (c JARMClaims) ToMapClaims() MapClaims { + return c.ToMap() +} + // FromMap will set the claims based on a mapping func (c *JARMClaims) FromMap(m map[string]any) { c.Extra = make(map[string]any) @@ -72,16 +117,21 @@ func (c *JARMClaims) FromMap(m map[string]any) { if aud, ok := StringSliceFromMap(v); ok { c.Audience = aud } - case consts.ClaimIssuedAt: - c.IssuedAt, _ = toTime(v, c.IssuedAt) - case consts.ClaimExpirationTime: - c.ExpiresAt, _ = toTime(v, c.ExpiresAt) + case ClaimIssuedAt: + c.IssuedAt, _ = toNumericDate(v) + case ClaimExpirationTime: + c.ExpirationTime, _ = toNumericDate(v) default: c.Extra[k] = v } } } +// FromMapClaims will populate claims from a jwt-go MapClaims representation +func (c *JARMClaims) FromMapClaims(mc MapClaims) { + c.FromMap(mc) +} + // Add will add a key-value pair to the extra field func (c *JARMClaims) Add(key string, value any) { if c.Extra == nil { @@ -96,12 +146,36 @@ func (c JARMClaims) Get(key string) any { return c.ToMap()[key] } -// ToMapClaims will return a jwt-go MapClaims representation -func (c JARMClaims) ToMapClaims() MapClaims { - return c.ToMap() +func (c JARMClaims) toNumericDate(key string) (date *NumericDate, err error) { + var ( + v any + ok bool + ) + + if v, ok = c.Extra[key]; !ok { + return nil, nil + } + + return toNumericDate(v) } -// FromMapClaims will populate claims from a jwt-go MapClaims representation -func (c *JARMClaims) FromMapClaims(mc MapClaims) { - c.FromMap(mc) +func (c JARMClaims) toString(key string) (value string, err error) { + var ( + ok bool + raw any + ) + + if raw, ok = c.Extra[key]; !ok { + return "", nil + } + + if value, ok = raw.(string); !ok { + return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) + } + + return value, nil } + +var ( + _ Claims = (*JARMClaims)(nil) +) diff --git a/token/jwt/claims_jarm_test.go b/token/jwt/claims_jarm_test.go index c35873f6..f114177e 100644 --- a/token/jwt/claims_jarm_test.go +++ b/token/jwt/claims_jarm_test.go @@ -13,11 +13,11 @@ import ( ) var jarmClaims = &JARMClaims{ - Issuer: "authelia", - Audience: []string{"tests"}, - JTI: "abcdef", - IssuedAt: time.Now().UTC().Round(time.Second), - ExpiresAt: time.Now().UTC().Add(time.Hour).Round(time.Second), + Issuer: "authelia", + Audience: []string{"tests"}, + JTI: "abcdef", + IssuedAt: Now(), + ExpirationTime: NewNumericDate(time.Now().Add(time.Hour)), Extra: map[string]any{ "foo": "bar", "baz": "bar", @@ -44,9 +44,9 @@ func TestJARMClaimsToMapSetsID(t *testing.T) { } func TestJARMAssert(t *testing.T) { - assert.Nil(t, (&JARMClaims{ExpiresAt: time.Now().UTC().Add(time.Hour)}). + assert.Nil(t, (&JARMClaims{ExpirationTime: NewNumericDate(time.Now().Add(time.Hour))}). ToMapClaims().Valid()) - assert.NotNil(t, (&JARMClaims{ExpiresAt: time.Now().UTC().Add(-2 * time.Hour)}). + assert.NotNil(t, (&JARMClaims{ExpirationTime: NewNumericDate(time.Now().Add(-2 * time.Hour))}). ToMapClaims().Valid()) } diff --git a/token/jwt/claims_jwt.go b/token/jwt/claims_jwt.go index ecaf5aea..db1bcf93 100644 --- a/token/jwt/claims_jwt.go +++ b/token/jwt/claims_jwt.go @@ -44,7 +44,7 @@ type JWTClaimsContainer interface { // WithScopeField configures how a scope field should be represented in JWT. WithScopeField(scopeField JWTScopeFieldEnum) JWTClaimsContainer - // ToMapClaims returns the claims as a github.com/dgrijalva/jwt-go.MapClaims type. + // ToMapClaims returns the claims as a MapClaims type. ToMapClaims() MapClaims } @@ -269,6 +269,8 @@ func toInt64(v any) (val int64, ok bool) { func toNumericDate(v any) (date *NumericDate, err error) { switch value := v.(type) { + case nil: + return nil, nil case float64: if value == 0 { return nil, nil diff --git a/token/jwt/claims_jwt_test.go b/token/jwt/claims_jwt_test.go index 653ea8ef..7cc2ee07 100644 --- a/token/jwt/claims_jwt_test.go +++ b/token/jwt/claims_jwt_test.go @@ -15,11 +15,11 @@ import ( var jwtClaims = &JWTClaims{ Subject: "peter", - IssuedAt: time.Now().UTC().Round(time.Second), + IssuedAt: time.Now().UTC().Truncate(TimePrecision), Issuer: "authelia", - NotBefore: time.Now().UTC().Round(time.Second), + NotBefore: time.Now().UTC().Truncate(TimePrecision), Audience: []string{"tests"}, - ExpiresAt: time.Now().UTC().Add(time.Hour).Round(time.Second), + ExpiresAt: time.Now().UTC().Add(time.Hour).Truncate(TimePrecision), JTI: "abcdef", Scope: []string{consts.ScopeEmail, consts.ScopeOffline}, Extra: map[string]any{ diff --git a/token/jwt/claims_map.go b/token/jwt/claims_map.go index 4e1620b3..904e0e09 100644 --- a/token/jwt/claims_map.go +++ b/token/jwt/claims_map.go @@ -5,10 +5,8 @@ package jwt import ( "bytes" - "crypto/subtle" "errors" "fmt" - "time" jjson "github.com/go-jose/go-jose/v4/json" @@ -39,7 +37,7 @@ func (m MapClaims) VerifyIssuer(cmp string, required bool) (ok bool) { return !required } - return verifyString(iss, cmp, required) + return validString(iss, cmp, required) } // GetSubject returns the 'sub' claim. @@ -63,7 +61,7 @@ func (m MapClaims) VerifySubject(cmp string, required bool) (ok bool) { return !required } - return verifyString(sub, cmp, required) + return validString(sub, cmp, required) } // GetAudience returns the 'aud' claim. @@ -151,7 +149,7 @@ func (m MapClaims) VerifyExpirationTime(cmp int64, required bool) (ok bool) { return !required } - return verifyInt64Future(exp.Int64(), cmp, required) + return validInt64Future(exp.Int64(), cmp, required) } // GetIssuedAt returns the 'iat' claim. @@ -175,7 +173,7 @@ func (m MapClaims) VerifyIssuedAt(cmp int64, required bool) (ok bool) { return !required } - return verifyInt64Past(iat.Int64(), cmp, required) + return validInt64Past(iat.Int64(), cmp, required) } // GetNotBefore returns the 'nbf' claim. @@ -199,7 +197,7 @@ func (m MapClaims) VerifyNotBefore(cmp int64, required bool) (ok bool) { return !required } - return verifyInt64Past(nbf.Int64(), cmp, required) + return validInt64Past(nbf.Int64(), cmp, required) } func (m MapClaims) ToMapClaims() MapClaims { @@ -361,139 +359,3 @@ func (m MapClaims) toClaimsString(key string) (ClaimStrings, error) { return cs, nil } - -type ClaimValidationOption func(opts *ClaimValidationOptions) - -type ClaimValidationOptions struct { - timef func() time.Time - iss string - aud []string - audAll []string - sub string - expRequired bool - iatRequired bool - nbfRequired bool -} - -func ValidateTimeFunc(timef func() time.Time) ClaimValidationOption { - return func(opts *ClaimValidationOptions) { - opts.timef = timef - } -} - -func ValidateIssuer(iss string) ClaimValidationOption { - return func(opts *ClaimValidationOptions) { - opts.iss = iss - } -} - -func ValidateAudienceAny(aud ...string) ClaimValidationOption { - return func(opts *ClaimValidationOptions) { - opts.aud = aud - } -} - -func ValidateAudienceAll(aud ...string) ClaimValidationOption { - return func(opts *ClaimValidationOptions) { - opts.audAll = aud - } -} - -func ValidateSubject(sub string) ClaimValidationOption { - return func(opts *ClaimValidationOptions) { - opts.sub = sub - } -} - -func ValidateRequireExpiresAt() ClaimValidationOption { - return func(opts *ClaimValidationOptions) { - opts.expRequired = true - } -} - -func ValidateRequireIssuedAt() ClaimValidationOption { - return func(opts *ClaimValidationOptions) { - opts.iatRequired = true - } -} - -func ValidateRequireNotBefore() ClaimValidationOption { - return func(opts *ClaimValidationOptions) { - opts.nbfRequired = true - } -} - -func verifyAud(aud []string, cmp string, required bool) bool { - if len(aud) == 0 { - return !required - } - - for _, a := range aud { - if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) == 1 { - return true - } - } - - return false -} - -func verifyAudAny(aud []string, cmp []string, required bool) bool { - if len(aud) == 0 { - return !required - } - - for _, c := range cmp { - for _, a := range aud { - if subtle.ConstantTimeCompare([]byte(a), []byte(c)) == 1 { - return true - } - } - } - - return false -} - -func verifyAudAll(aud []string, cmp []string, required bool) bool { - if len(aud) == 0 { - return !required - } - -outer: - for _, c := range cmp { - for _, a := range aud { - if subtle.ConstantTimeCompare([]byte(a), []byte(c)) == 1 { - continue outer - } - } - - return false - } - - return true -} - -// verifyInt64Future ensures the given value is in the future. -func verifyInt64Future(value, now int64, required bool) bool { - if value == 0 { - return !required - } - - return now <= value -} - -// verifyInt64Past ensures the given value is in the past or the current value. -func verifyInt64Past(value, now int64, required bool) bool { - if value == 0 { - return !required - } - - return now >= value -} - -func verifyString(value, cmp string, required bool) bool { - if value == "" { - return !required - } - - return subtle.ConstantTimeCompare([]byte(value), []byte(cmp)) == 1 -} diff --git a/token/jwt/claims_test.go b/token/jwt/claims_test.go index e06f39b9..c5245036 100644 --- a/token/jwt/claims_test.go +++ b/token/jwt/claims_test.go @@ -21,7 +21,7 @@ func TestToTime(t *testing.T) { assert.Equal(t, time.Time{}, ToTime(nil)) assert.Equal(t, time.Time{}, ToTime("1234")) - now := time.Now().UTC().Round(time.Second) + now := time.Now().UTC().Truncate(TimePrecision) assert.Equal(t, now, ToTime(now)) assert.Equal(t, now, ToTime(now.Unix())) assert.Equal(t, now, ToTime(float64(now.Unix()))) diff --git a/token/jwt/date.go b/token/jwt/date.go index 187f5cbf..53eab0ac 100644 --- a/token/jwt/date.go +++ b/token/jwt/date.go @@ -1,6 +1,7 @@ package jwt import ( + "crypto/subtle" "encoding/json" "errors" "fmt" @@ -13,21 +14,28 @@ type NumericDate struct { time.Time } +func Now() *NumericDate { + return NewNumericDate(time.Now()) +} + func NewNumericDate(t time.Time) *NumericDate { - return &NumericDate{t.Truncate(TimePrecision)} + return &NumericDate{t.UTC().Truncate(TimePrecision)} } func newNumericDateFromSeconds(f float64) *NumericDate { round, frac := math.Modf(f) + return NewNumericDate(time.Unix(int64(round), int64(frac*1e9))) } func (date NumericDate) MarshalJSON() (b []byte, err error) { var prec int + if TimePrecision < time.Second { prec = int(math.Log10(float64(time.Second) / float64(TimePrecision))) } - truncatedDate := date.Truncate(TimePrecision) + + truncatedDate := date.UTC().Truncate(TimePrecision) seconds := strconv.FormatInt(truncatedDate.Unix(), 10) nanosecondsOffset := strconv.FormatFloat(float64(truncatedDate.Nanosecond())/float64(time.Second), 'f', prec, 64) @@ -69,6 +77,55 @@ func (date *NumericDate) Int64() (val int64) { type ClaimStrings []string +func (s ClaimStrings) Valid(cmp string, required bool) (valid bool) { + if len(s) == 0 { + return !required + } + + for _, str := range s { + if subtle.ConstantTimeCompare([]byte(str), []byte(cmp)) == 1 { + return true + } + } + + return false +} + +func (s ClaimStrings) ValidAny(cmp ClaimStrings, required bool) (valid bool) { + if len(s) == 0 { + return !required + } + + for _, strCmp := range cmp { + for _, str := range s { + if subtle.ConstantTimeCompare([]byte(str), []byte(strCmp)) == 1 { + return true + } + } + } + + return false +} + +func (s ClaimStrings) ValidAll(cmp ClaimStrings, required bool) (valid bool) { + if len(s) == 0 { + return !required + } + +outer: + for _, strCmp := range cmp { + for _, str := range s { + if subtle.ConstantTimeCompare([]byte(str), []byte(strCmp)) == 1 { + continue outer + } + } + + return false + } + + return true +} + func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) { var value interface{} diff --git a/token/jwt/token.go b/token/jwt/token.go index b85e72fe..2dd97764 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -127,12 +127,7 @@ func ParseCustomWithClaims(tokenString string, claims MapClaims, keyFunc Keyfunc return token, nil } -// Token represets a JWT Token -// This token provide an adaptation to -// transit from [jwt-go](https://github.com/dgrijalva/jwt-go) -// to [go-jose](https://github.com/square/go-jose) -// It provides method signatures compatible with jwt-go but implemented -// using go-json +// Token represets a JWT Token. type Token struct { KeyID string SignatureAlgorithm jose.SignatureAlgorithm // alg (JWS) diff --git a/token/jwt/token_test.go b/token/jwt/token_test.go index 170598dd..9169170a 100644 --- a/token/jwt/token_test.go +++ b/token/jwt/token_test.go @@ -101,7 +101,6 @@ var ( nilKeyFunc Keyfunc = nil ) -// Many test cases where taken from https://github.com/dgrijalva/jwt-go/blob/master/parser_test.go // Test cases related to json.Number where excluded because that is not supported by go-jose, // it is not used here and therefore not supported. // diff --git a/token/jwt/validate.go b/token/jwt/validate.go new file mode 100644 index 00000000..1797d1cd --- /dev/null +++ b/token/jwt/validate.go @@ -0,0 +1,164 @@ +package jwt + +import ( + "crypto/subtle" + "time" +) + +type ClaimValidationOption func(opts *ClaimValidationOptions) + +type ClaimValidationOptions struct { + timef func() time.Time + iss string + aud []string + audAll []string + sub string + expRequired bool + iatRequired bool + nbfRequired bool +} + +func ValidateTimeFunc(timef func() time.Time) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.timef = timef + } +} + +func ValidateIssuer(iss string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.iss = iss + } +} + +func ValidateAudienceAny(aud ...string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.aud = aud + } +} + +func ValidateAudienceAll(aud ...string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.audAll = aud + } +} + +func ValidateSubject(sub string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.sub = sub + } +} + +func ValidateRequireExpiresAt() ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.expRequired = true + } +} + +func ValidateRequireIssuedAt() ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.iatRequired = true + } +} + +func ValidateRequireNotBefore() ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.nbfRequired = true + } +} + +func verifyAud(aud []string, cmp string, required bool) bool { + if len(aud) == 0 { + return !required + } + + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) == 1 { + return true + } + } + + return false +} + +func verifyAudAny(aud []string, cmp []string, required bool) bool { + if len(aud) == 0 { + return !required + } + + for _, c := range cmp { + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(c)) == 1 { + return true + } + } + } + + return false +} + +func verifyAudAll(aud []string, cmp []string, required bool) bool { + if len(aud) == 0 { + return !required + } + +outer: + for _, c := range cmp { + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(c)) == 1 { + continue outer + } + } + + return false + } + + return true +} + +// validInt64Future ensures the given value is in the future. +func validInt64Future(value, now int64, required bool) bool { + if value == 0 { + return !required + } + + return now <= value +} + +// validInt64Past ensures the given value is in the past or the current value. +func validInt64Past(value, now int64, required bool) bool { + if value == 0 { + return !required + } + + return now >= value +} + +func validString(value, cmp string, required bool) bool { + if value == "" { + return !required + } + + return subtle.ConstantTimeCompare([]byte(value), []byte(cmp)) == 1 +} + +type validDateFunc func(value, now int64, required bool) bool + +func validDate(valid validDateFunc, now int64, required bool, date *NumericDate, err error) bool { + if err != nil || valid == nil { + return false + } + + if date == nil { + if required { + return false + } + + return true + } + + if valid(date.Int64(), now, required) { + return true + } + + return false +} From 2a7c7fa5f7e3281ac6d55df2abfc21d74018a44d Mon Sep 17 00:00:00 2001 From: James Elliott Date: Thu, 3 Oct 2024 22:54:03 +1000 Subject: [PATCH 31/33] feat: claims interface --- token/jwt/claims_id_token.go | 159 +++++++++++++++++------------- token/jwt/claims_id_token_test.go | 2 + token/jwt/util.go | 21 +++- 3 files changed, 113 insertions(+), 69 deletions(-) diff --git a/token/jwt/claims_id_token.go b/token/jwt/claims_id_token.go index 4dec51a2..e63c325b 100644 --- a/token/jwt/claims_id_token.go +++ b/token/jwt/claims_id_token.go @@ -5,6 +5,7 @@ package jwt import ( "bytes" + "encoding/json" "errors" "fmt" "time" @@ -21,18 +22,18 @@ type IDTokenClaims struct { JTI string `json:"jti"` Issuer string `json:"iss"` Subject string `json:"sub"` - Audience []string `json:"aud"` - Nonce string `json:"nonce"` - ExpirationTime *NumericDate `json:"exp"` - IssuedAt *NumericDate `json:"iat"` - RequestedAt *NumericDate `json:"rat"` - AuthTime *NumericDate `json:"auth_time"` - AccessTokenHash string `json:"at_hash"` - AuthenticationContextClassReference string `json:"acr"` - AuthenticationMethodsReferences []string `json:"amr"` - CodeHash string `json:"c_hash"` - StateHash string `json:"s_hash"` - Extra map[string]any `json:"ext"` + Audience []string `json:"aud,omitempty"` + Nonce string `json:"nonce,omitempty"` + ExpirationTime *NumericDate `json:"exp,omitempty"` + IssuedAt *NumericDate `json:"iat,omitempty"` + RequestedAt *NumericDate `json:"rat,omitempty"` + AuthTime *NumericDate `json:"auth_time,omitempty"` + AuthenticationContextClassReference string `json:"acr,omitempty"` + AuthenticationMethodsReferences []string `json:"amr,omitempty"` + AccessTokenHash string `json:"at_hash,omitempty"` + CodeHash string `json:"c_hash,omitempty"` + StateHash string `json:"s_hash,omitempty"` + Extra map[string]any `json:"ext,omitempty"` } func (c *IDTokenClaims) GetExpirationTime() (exp *NumericDate, err error) { @@ -170,6 +171,12 @@ func (c *IDTokenClaims) GetRequestedAtSafe() time.Time { return c.RequestedAt.UTC() } +func (c *IDTokenClaims) MarshalJSON() (data []byte, err error) { + claims := c.ToMapClaims() + + return json.Marshal(claims) +} + func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { claims := MapClaims{} @@ -196,32 +203,7 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { case ClaimSubject: c.Subject, ok = value.(string) case ClaimAudience: - switch aud := value.(type) { - case nil: - ok = true - case string: - ok = true - - c.Audience = []string{aud} - case []string: - ok = true - - c.Audience = aud - case []any: - ok = true - - loop: - for _, av := range aud { - switch a := av.(type) { - case string: - c.Audience = append(c.Audience, a) - default: - ok = false - - break loop - } - } - } + c.Audience, ok = toStringSlice(value) case ClaimNonce: c.Nonce, ok = value.(string) case ClaimExpirationTime: @@ -240,12 +222,16 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { if c.AuthTime, err = toNumericDate(value); err == nil { ok = true } + case ClaimAuthenticationContextClassReference: + c.AuthenticationContextClassReference, ok = value.(string) + case ClaimAuthenticationMethodsReference: + c.AuthenticationMethodsReferences, ok = toStringSlice(value) + case ClaimAccessTokenHash: + c.AccessTokenHash, ok = value.(string) case ClaimCodeHash: c.CodeHash, ok = value.(string) case ClaimStateHash: c.StateHash, ok = value.(string) - case ClaimAuthenticationContextClassReference: - c.AuthenticationContextClassReference, ok = value.(string) default: if c.Extra == nil { c.Extra = make(map[string]any) @@ -268,10 +254,10 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { func (c *IDTokenClaims) ToMap() map[string]any { var ret = Copy(c.Extra) - if c.Subject != "" { - ret[ClaimSubject] = c.Subject + if c.JTI != "" { + ret[consts.ClaimJWTID] = c.JTI } else { - delete(ret, ClaimSubject) + ret[consts.ClaimJWTID] = uuid.New().String() } if c.Issuer != "" { @@ -280,28 +266,16 @@ func (c *IDTokenClaims) ToMap() map[string]any { delete(ret, consts.ClaimIssuer) } - if c.JTI != "" { - ret[consts.ClaimJWTID] = c.JTI + if c.Subject != "" { + ret[ClaimSubject] = c.Subject } else { - ret[consts.ClaimJWTID] = uuid.New().String() + delete(ret, ClaimSubject) } if len(c.Audience) > 0 { ret[consts.ClaimAudience] = c.Audience } else { - ret[consts.ClaimAudience] = []string{} - } - - if c.IssuedAt != nil { - ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() - } else { - delete(ret, consts.ClaimIssuedAt) - } - - if c.ExpirationTime != nil { - ret[consts.ClaimExpirationTime] = c.ExpirationTime.Unix() - } else { - delete(ret, consts.ClaimExpirationTime) + delete(ret, ClaimAudience) } if len(c.Nonce) > 0 { @@ -310,22 +284,22 @@ func (c *IDTokenClaims) ToMap() map[string]any { delete(ret, consts.ClaimNonce) } - if len(c.AccessTokenHash) > 0 { - ret[consts.ClaimAccessTokenHash] = c.AccessTokenHash + if c.ExpirationTime != nil { + ret[consts.ClaimExpirationTime] = c.ExpirationTime.Unix() } else { - delete(ret, consts.ClaimAccessTokenHash) + delete(ret, consts.ClaimExpirationTime) } - if len(c.CodeHash) > 0 { - ret[consts.ClaimCodeHash] = c.CodeHash + if c.IssuedAt != nil { + ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() } else { - delete(ret, consts.ClaimCodeHash) + delete(ret, consts.ClaimIssuedAt) } - if len(c.StateHash) > 0 { - ret[consts.ClaimStateHash] = c.StateHash + if c.RequestedAt != nil { + ret[consts.ClaimRequestedAt] = c.RequestedAt.Unix() } else { - delete(ret, consts.ClaimStateHash) + delete(ret, consts.ClaimRequestedAt) } if c.AuthTime != nil { @@ -346,6 +320,24 @@ func (c *IDTokenClaims) ToMap() map[string]any { delete(ret, consts.ClaimAuthenticationMethodsReference) } + if len(c.AccessTokenHash) > 0 { + ret[consts.ClaimAccessTokenHash] = c.AccessTokenHash + } else { + delete(ret, consts.ClaimAccessTokenHash) + } + + if len(c.CodeHash) > 0 { + ret[consts.ClaimCodeHash] = c.CodeHash + } else { + delete(ret, consts.ClaimCodeHash) + } + + if len(c.StateHash) > 0 { + ret[consts.ClaimStateHash] = c.StateHash + } else { + delete(ret, consts.ClaimStateHash) + } + return ret } @@ -381,6 +373,37 @@ func (c IDTokenClaims) toNumericDate(key string) (date *NumericDate, err error) return toNumericDate(v) } +func toStringSlice(value any) (values []string, ok bool) { + switch t := value.(type) { + case nil: + ok = true + case string: + ok = true + + values = []string{t} + case []string: + ok = true + + values = t + case []any: + ok = true + + loop: + for _, tv := range t { + switch vv := tv.(type) { + case string: + values = append(values, vv) + default: + ok = false + + break loop + } + } + } + + return values, ok +} + var ( _ Claims = (*IDTokenClaims)(nil) ) diff --git a/token/jwt/claims_id_token_test.go b/token/jwt/claims_id_token_test.go index bc35493f..8c22119c 100644 --- a/token/jwt/claims_id_token_test.go +++ b/token/jwt/claims_id_token_test.go @@ -49,6 +49,7 @@ func TestIDTokenClaimsToMap(t *testing.T) { ClaimIssuer: idTokenClaims.Issuer, ClaimAudience: idTokenClaims.Audience, ClaimExpirationTime: idTokenClaims.ExpirationTime.Unix(), + ClaimRequestedAt: idTokenClaims.RequestedAt.Unix(), "foo": idTokenClaims.Extra["foo"], "baz": idTokenClaims.Extra["baz"], ClaimAccessTokenHash: idTokenClaims.AccessTokenHash, @@ -67,6 +68,7 @@ func TestIDTokenClaimsToMap(t *testing.T) { consts.ClaimIssuer: idTokenClaims.Issuer, consts.ClaimAudience: idTokenClaims.Audience, consts.ClaimExpirationTime: idTokenClaims.ExpirationTime.Unix(), + consts.ClaimRequestedAt: idTokenClaims.RequestedAt.Unix(), "foo": idTokenClaims.Extra["foo"], "baz": idTokenClaims.Extra["baz"], consts.ClaimAccessTokenHash: idTokenClaims.AccessTokenHash, diff --git a/token/jwt/util.go b/token/jwt/util.go index f97ec97c..2f299772 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -276,8 +276,27 @@ func NewClientSecretJWK(ctx context.Context, secret []byte, kid, alg, enc, use s switch use { case JSONWebTokenUseSignature: + var ( + hasher hash.Hash + ) + + switch jose.SignatureAlgorithm(alg) { + case jose.HS256: + hasher = sha256.New() + case jose.HS384: + hasher = sha512.New384() + case jose.HS512: + hasher = sha512.New() + default: + return nil, &JWKLookupError{Description: fmt.Sprintf("Unsupported algorithm '%s'", alg)} + } + + if _, err = hasher.Write(secret); err != nil { + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to derive key from hashing the client secret. %s", err.Error())} + } + return &jose.JSONWebKey{ - Key: secret, + Key: hasher.Sum(nil), KeyID: kid, Algorithm: alg, Use: use, From 7637fe0da962c80798ec8a813fcd3cd563c16c1e Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 4 Oct 2024 21:47:16 +1000 Subject: [PATCH 32/33] feat: claims interface --- token/jwt/claims_id_token.go | 94 +++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 45 deletions(-) diff --git a/token/jwt/claims_id_token.go b/token/jwt/claims_id_token.go index e63c325b..7bddfe33 100644 --- a/token/jwt/claims_id_token.go +++ b/token/jwt/claims_id_token.go @@ -5,7 +5,6 @@ package jwt import ( "bytes" - "encoding/json" "errors" "fmt" "time" @@ -22,14 +21,15 @@ type IDTokenClaims struct { JTI string `json:"jti"` Issuer string `json:"iss"` Subject string `json:"sub"` - Audience []string `json:"aud,omitempty"` - Nonce string `json:"nonce,omitempty"` - ExpirationTime *NumericDate `json:"exp,omitempty"` - IssuedAt *NumericDate `json:"iat,omitempty"` - RequestedAt *NumericDate `json:"rat,omitempty"` + Audience []string `json:"aud"` + ExpirationTime *NumericDate `json:"exp"` + IssuedAt *NumericDate `json:"iat"` AuthTime *NumericDate `json:"auth_time,omitempty"` + RequestedAt *NumericDate `json:"rat,omitempty"` + Nonce string `json:"nonce,omitempty"` AuthenticationContextClassReference string `json:"acr,omitempty"` AuthenticationMethodsReferences []string `json:"amr,omitempty"` + AuthorizedParty string `json:"azp,omitempty"` AccessTokenHash string `json:"at_hash,omitempty"` CodeHash string `json:"c_hash,omitempty"` StateHash string `json:"s_hash,omitempty"` @@ -171,12 +171,6 @@ func (c *IDTokenClaims) GetRequestedAtSafe() time.Time { return c.RequestedAt.UTC() } -func (c *IDTokenClaims) MarshalJSON() (data []byte, err error) { - claims := c.ToMapClaims() - - return json.Marshal(claims) -} - func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { claims := MapClaims{} @@ -204,8 +198,6 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { c.Subject, ok = value.(string) case ClaimAudience: c.Audience, ok = toStringSlice(value) - case ClaimNonce: - c.Nonce, ok = value.(string) case ClaimExpirationTime: if c.ExpirationTime, err = toNumericDate(value); err == nil { ok = true @@ -214,24 +206,30 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { if c.IssuedAt, err = toNumericDate(value); err == nil { ok = true } - case ClaimRequestedAt: - if c.RequestedAt, err = toNumericDate(value); err == nil { - ok = true - } case ClaimAuthenticationTime: if c.AuthTime, err = toNumericDate(value); err == nil { ok = true } + case ClaimRequestedAt: + if c.RequestedAt, err = toNumericDate(value); err == nil { + ok = true + } + case ClaimNonce: + c.Nonce, ok = value.(string) case ClaimAuthenticationContextClassReference: c.AuthenticationContextClassReference, ok = value.(string) case ClaimAuthenticationMethodsReference: c.AuthenticationMethodsReferences, ok = toStringSlice(value) + case ClaimAuthorizedParty: + c.AuthorizedParty, ok = value.(string) case ClaimAccessTokenHash: c.AccessTokenHash, ok = value.(string) case ClaimCodeHash: c.CodeHash, ok = value.(string) case ClaimStateHash: c.StateHash, ok = value.(string) + case ClaimExtra: + c.Extra, ok = value.(map[string]any) default: if c.Extra == nil { c.Extra = make(map[string]any) @@ -255,9 +253,9 @@ func (c *IDTokenClaims) ToMap() map[string]any { var ret = Copy(c.Extra) if c.JTI != "" { - ret[consts.ClaimJWTID] = c.JTI + ret[ClaimJWTID] = c.JTI } else { - ret[consts.ClaimJWTID] = uuid.New().String() + ret[ClaimJWTID] = uuid.New().String() } if c.Issuer != "" { @@ -273,69 +271,75 @@ func (c *IDTokenClaims) ToMap() map[string]any { } if len(c.Audience) > 0 { - ret[consts.ClaimAudience] = c.Audience + ret[ClaimAudience] = c.Audience } else { delete(ret, ClaimAudience) } - if len(c.Nonce) > 0 { - ret[consts.ClaimNonce] = c.Nonce + if c.ExpirationTime != nil { + ret[ClaimExpirationTime] = c.ExpirationTime.Unix() } else { - delete(ret, consts.ClaimNonce) + delete(ret, ClaimExpirationTime) } - if c.ExpirationTime != nil { - ret[consts.ClaimExpirationTime] = c.ExpirationTime.Unix() + if c.IssuedAt != nil { + ret[ClaimIssuedAt] = c.IssuedAt.Unix() } else { - delete(ret, consts.ClaimExpirationTime) + delete(ret, ClaimIssuedAt) } - if c.IssuedAt != nil { - ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() + if c.AuthTime != nil { + ret[ClaimAuthenticationTime] = c.AuthTime.Unix() } else { - delete(ret, consts.ClaimIssuedAt) + delete(ret, ClaimAuthenticationTime) } if c.RequestedAt != nil { - ret[consts.ClaimRequestedAt] = c.RequestedAt.Unix() + ret[ClaimRequestedAt] = c.RequestedAt.Unix() } else { - delete(ret, consts.ClaimRequestedAt) + delete(ret, ClaimRequestedAt) } - if c.AuthTime != nil { - ret[consts.ClaimAuthenticationTime] = c.AuthTime.Unix() + if len(c.Nonce) > 0 { + ret[ClaimNonce] = c.Nonce } else { - delete(ret, consts.ClaimAuthenticationTime) + delete(ret, ClaimNonce) } if len(c.AuthenticationContextClassReference) > 0 { - ret[consts.ClaimAuthenticationContextClassReference] = c.AuthenticationContextClassReference + ret[ClaimAuthenticationContextClassReference] = c.AuthenticationContextClassReference } else { - delete(ret, consts.ClaimAuthenticationContextClassReference) + delete(ret, ClaimAuthenticationContextClassReference) } if len(c.AuthenticationMethodsReferences) > 0 { - ret[consts.ClaimAuthenticationMethodsReference] = c.AuthenticationMethodsReferences + ret[ClaimAuthenticationMethodsReference] = c.AuthenticationMethodsReferences + } else { + delete(ret, ClaimAuthenticationMethodsReference) + } + + if len(c.AuthorizedParty) > 0 { + ret[ClaimAuthorizedParty] = c.AuthorizedParty } else { - delete(ret, consts.ClaimAuthenticationMethodsReference) + delete(ret, ClaimAuthorizedParty) } if len(c.AccessTokenHash) > 0 { - ret[consts.ClaimAccessTokenHash] = c.AccessTokenHash + ret[ClaimAccessTokenHash] = c.AccessTokenHash } else { - delete(ret, consts.ClaimAccessTokenHash) + delete(ret, ClaimAccessTokenHash) } if len(c.CodeHash) > 0 { - ret[consts.ClaimCodeHash] = c.CodeHash + ret[ClaimCodeHash] = c.CodeHash } else { - delete(ret, consts.ClaimCodeHash) + delete(ret, ClaimCodeHash) } if len(c.StateHash) > 0 { - ret[consts.ClaimStateHash] = c.StateHash + ret[ClaimStateHash] = c.StateHash } else { - delete(ret, consts.ClaimStateHash) + delete(ret, ClaimStateHash) } return ret From f1570691d81ad67de279687a6389c6387cf282d9 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Mon, 21 Oct 2024 14:51:11 +1100 Subject: [PATCH 33/33] refactor: map claims converter --- token/jwt/claims_map.go | 5 +++ token/jwt/util.go | 76 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/token/jwt/claims_map.go b/token/jwt/claims_map.go index 904e0e09..113f804f 100644 --- a/token/jwt/claims_map.go +++ b/token/jwt/claims_map.go @@ -13,6 +13,11 @@ import ( "authelia.com/provider/oauth2/x/errorsx" ) +// NewMapClaims returns a set of MapClaims from an object that has the appropriate JSON tags. +func NewMapClaims(obj any) (claims MapClaims) { + return toMap(obj) +} + // MapClaims is a simple map based claims structure. type MapClaims map[string]any diff --git a/token/jwt/util.go b/token/jwt/util.go index 2f299772..20e3ccdd 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -8,6 +8,7 @@ import ( "crypto/sha512" "fmt" "hash" + "reflect" "regexp" "strings" @@ -467,3 +468,78 @@ func newError(message string, err error, more ...error) error { err = fmt.Errorf(format, args...) return err } + +func toMap(obj any) (result map[string]any) { + result = map[string]any{} + + if obj == nil { + return result + } + + v := reflect.TypeOf(obj) + + reflectValue := reflect.ValueOf(obj) + reflectValue = reflect.Indirect(reflectValue) + + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + for i := 0; i < v.NumField(); i++ { + tag, opts := parseTag(v.Field(i).Tag.Get("json")) + field := reflectValue.Field(i).Interface() + if tag != "" && tag != "-" { + if opts.Contains("omitempty") && isEmptyValue(reflect.ValueOf(field)) { + continue + } + + if v.Field(i).Type.Kind() == reflect.Struct { + result[tag] = toMap(field) + } else { + result[tag] = field + } + } + } + + return result +} + +type tagOptionsJSON string + +func parseTag(tag string) (string, tagOptionsJSON) { + tag, opt, _ := strings.Cut(tag, ",") + return tag, tagOptionsJSON(opt) +} + +func (o tagOptionsJSON) Contains(optionName string) bool { + if len(o) == 0 { + return false + } + + s := string(o) + + for s != "" { + var name string + name, s, _ = strings.Cut(s, ",") + if name == optionName { + return true + } + } + + return false +} + +func isEmptyValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Float32, reflect.Float64, + reflect.Interface, reflect.Pointer: + return v.IsZero() + default: + return false + } +}