diff --git a/README.md b/README.md index 15ce90d2..a623cdd5 100644 --- a/README.md +++ b/README.md @@ -22,9 +22,9 @@ following list of differences: - [x] Minimum dependency is go version 1.21 - [x] Replace string values with constants where applicable [commit](https://github.com/authelia/oauth2-provider/commit/de536dc0c9cd5f080c387621799e644319587bd0) -- [ ] Simplify the internal JWT logic to leverage `github.com/golang-jwt/jwt/v5` +- [x] Simplify the internal JWT logic to leverage `github.com/golang-jwt/jwt/v5` or other such libraries -- [ ] Implement internal JWKS logic +- [x] Implement internal JWKS logic - [x] Higher Debug error information visibility (Debug Field includes the complete RFC6749 error with debug information if available) - Fixes: @@ -103,10 +103,10 @@ following list of differences: - [x] General Refactor - [x] Prevent Multiple Client Authentication Methods - [x] Client Secret Validation Interface - - [ ] JWE support for Client Authentication and Issuance + - [x] JWE support for Client Authentication and Issuance - [x] Testing Package (mocks, etc) - [ ] Clock Drift Support - - [ ] Key Management + - [x] Key Management - [ ] Injectable Clock Configurator - [x] Support `s_hash` [commit](https://github.com/authelia/oauth2-provider/commit/edbbbe9467c70a2578db4b9af4d6cd319f74886e) @@ -125,6 +125,7 @@ following list of differences: - [x] `github.com/gobuffalo/packr` - [x] `github.com/form3tech-oss/jwt-go` - [x] `github.com/dgrijalva/jwt-go` + - [x] `github.com/golang-jwt/jwt` - Migration of the following dependencies: - [x] `github.com/go-jose/go-jose/v3` => `github.com/go-jose/go-jose/v4` - [x] `github.com/golang/mock` => `github.com/uber-go/mock` diff --git a/access_error_test.go b/access_error_test.go index d43f9f85..1485e131 100644 --- a/access_error_test.go +++ b/access_error_test.go @@ -30,7 +30,7 @@ func TestWriteAccessError(t *testing.T) { rw.EXPECT().WriteHeader(http.StatusBadRequest) rw.EXPECT().Write(gomock.Any()) - provider.WriteAccessError(context.Background(), rw, nil, ErrInvalidRequest) + provider.WriteAccessError(context.TODO(), rw, nil, ErrInvalidRequest) } func TestWriteAccessError_RFC6749(t *testing.T) { @@ -62,7 +62,7 @@ func TestWriteAccessError_RFC6749(t *testing.T) { config.UseLegacyErrorFormat = c.includeExtraFields rw := httptest.NewRecorder() - provider.WriteAccessError(context.Background(), rw, nil, c.err) + provider.WriteAccessError(context.TODO(), rw, nil, c.err) var params struct { Error string `json:"error"` // specified by RFC, required diff --git a/access_write_test.go b/access_write_test.go index 9b7124d2..6de92459 100644 --- a/access_write_test.go +++ b/access_write_test.go @@ -30,7 +30,7 @@ func TestWriteAccessResponse(t *testing.T) { rw.EXPECT().Write(gomock.Any()) resp.EXPECT().ToMap().Return(map[string]any{}) - provider.WriteAccessResponse(context.Background(), rw, ar, resp) + provider.WriteAccessResponse(context.TODO(), rw, ar, resp) assert.Equal(t, consts.ContentTypeApplicationJSON, header.Get(consts.HeaderContentType)) assert.Equal(t, consts.CacheControlNoStore, header.Get(consts.HeaderCacheControl)) assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) diff --git a/arguments_test.go b/arguments_test.go index a60c8035..94b2e51f 100644 --- a/arguments_test.go +++ b/arguments_test.go @@ -54,7 +54,6 @@ func TestArgumentsExactOne(t *testing.T) { for k, c := range testCases { assert.Equal(t, c.expect, c.args.ExactOne(c.exact), "%d", k) - t.Logf("Passed test case %d", k) } } @@ -106,7 +105,6 @@ func TestArgumentsHas(t *testing.T) { }, } { assert.Equal(t, c.expect, c.args.Has(c.has...), "%d", k) - t.Logf("Passed test case %d", k) } } @@ -192,7 +190,6 @@ func TestArgumentsMatchesExact(t *testing.T) { }...) for k, c := range testCases { assert.Equal(t, c.expect, c.args.MatchesExact(c.is...), "%d", k) - t.Logf("Passed test case %d", k) } } @@ -224,7 +221,6 @@ func TestArgumentsMatches(t *testing.T) { }...) for k, c := range testCases { assert.Equal(t, c.expect, c.args.Matches(c.is...), "%d", k) - t.Logf("Passed test case %d", k) } } @@ -251,6 +247,5 @@ func TestArgumentsOneOf(t *testing.T) { }, } { assert.Equal(t, c.expect, c.args.HasOneOf(c.oneOf...), "%d", k) - t.Logf("Passed test case %d", k) } } diff --git a/authorize_helper_test.go b/authorize_helper_test.go index f7bf6c98..1e6a1fdc 100644 --- a/authorize_helper_test.go +++ b/authorize_helper_test.go @@ -221,7 +221,6 @@ func TestDoesClientWhiteListRedirect(t *testing.T) { require.NotNil(t, redir, "%d", k) assert.Equal(t, c.expected, redir.String(), "%d", k) } - t.Logf("Passed test case %d", k) } } diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 6bd8b451..89fa4f5a 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -9,8 +9,8 @@ import ( "io" "net/http" "strings" + "time" - "github.com/go-jose/go-jose/v4" "github.com/pkg/errors" "authelia.com/provider/oauth2/i18n" @@ -20,14 +20,6 @@ import ( "authelia.com/provider/oauth2/x/errorsx" ) -func wrapSigningKeyFailure(outer *RFC6749Error, inner error) *RFC6749Error { - outer = outer.WithWrap(inner).WithDebugError(inner) - if e := new(RFC6749Error); errors.As(inner, &e) { - return outer.WithHintf("%s %s", outer.Reason(), e.Reason()) - } - return outer -} - // TODO: Refactor time permitting. // //nolint:gocyclo @@ -74,12 +66,13 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co } var ( - algAny, algNone bool + alg string + algAny bool ) - switch alg := client.GetRequestObjectSigningAlg(); alg { + switch alg = client.GetRequestObjectSigningAlg(); alg { case consts.JSONWebTokenAlgNone: - algNone = true + break case "": algAny = true default: @@ -123,65 +116,32 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co assertion = request.Form.Get(consts.FormParameterRequest) } - token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (key any, err error) { - // request_object_signing_alg - OPTIONAL. - // JWS [JWS] alg algorithm [JWA] that MUST be used for signing Request Objects sent to the OP. All Request Objects from this Client MUST be rejected, - // if not signed with this algorithm. Request Objects are described in Section 6.1 of OpenID Connect Core 1.0 [OpenID.Core]. This algorithm MUST - // be used both when the Request Object is passed by value (using the request parameter) and when it is passed by reference (using the request_uri parameter). - // Servers SHOULD support RS256. The value none MAY be used. The default, if omitted, is that any algorithm supported by the OP and the RP MAY be used. - if !algAny && client.GetRequestObjectSigningAlg() != fmt.Sprintf("%s", t.Header[consts.JSONWebTokenHeaderAlgorithm]) { - return nil, errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' expects request objects to be signed with the '%s' algorithm but the request object was signed with the '%s' algorithm.", request.GetClient().GetID(), client.GetRequestObjectSigningAlg(), t.Header[consts.JSONWebTokenHeaderAlgorithm])) - } - - if t.Method == jwt.SigningMethodNone { - algNone = true - - return jwt.UnsafeAllowNoneSignatureType, nil - } else if algNone { - return nil, errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' expects request objects to be signed with the '%s' algorithm but the request object was signed with the '%s' algorithm.", request.GetClient().GetID(), client.GetRequestObjectSigningAlg(), t.Header[consts.JSONWebTokenHeaderAlgorithm])) - } - - switch t.Method { - case jose.RS256, jose.RS384, jose.RS512: - if key, err = f.findClientPublicJWK(ctx, client, t, true); err != nil { - return nil, wrapSigningKeyFailure( - ErrInvalidRequestObject.WithHint("Unable to retrieve RSA signing key from OAuth 2.0 Client."), err) - } + issuer := f.Config.GetIDTokenIssuer(ctx) - return key, nil - case jose.ES256, jose.ES384, jose.ES512: - if key, err = f.findClientPublicJWK(ctx, client, t, false); err != nil { - return nil, wrapSigningKeyFailure( - ErrInvalidRequestObject.WithHint("Unable to retrieve ECDSA signing key from OAuth 2.0 Client."), err) - } - - return key, nil - case jose.PS256, jose.PS384, jose.PS512: - if key, err = f.findClientPublicJWK(ctx, client, t, true); err != nil { - return nil, wrapSigningKeyFailure( - ErrInvalidRequestObject.WithHint("Unable to retrieve RSA signing key from OAuth 2.0 Client."), err) - } - - return key, nil - default: - return nil, errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object that uses the unsupported signing algorithm '%s'.", request.GetClient().GetID(), t.Header[consts.JSONWebTokenHeaderAlgorithm])) - } - }) + strategy := f.Config.GetJWTStrategy(ctx) + token, err := strategy.Decode(ctx, assertion, jwt.WithSigAlgorithm(jwt.SignatureAlgorithmsNone...), jwt.WithJARClient(client)) if err != nil { - // Do not re-process already enhanced errors - var e *jwt.ValidationError - if errors.As(err, &e) { - if e.Inner != nil { - return e.Inner - } + return errorsx.WithStack(fmtRequestObjectDecodeError(token, client, issuer, openid, err)) + } - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object which failed to validate with error: %+v.", request.GetClient().GetID(), err).WithWrap(err)) - } + optsValidHeader := []jwt.HeaderValidationOption{ + jwt.ValidateKeyID(client.GetRequestObjectSigningKeyID()), + jwt.ValidateAlgorithm(client.GetRequestObjectSigningAlg()), + jwt.ValidateEncryptionKeyID(client.GetRequestObjectEncryptionKeyID()), + jwt.ValidateKeyAlgorithm(client.GetRequestObjectEncryptionAlg()), + jwt.ValidateContentEncryption(client.GetRequestObjectEncryptionEnc()), + } - return err - } else if err = token.Claims.Valid(); err != nil { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object which could not be validated because its claims could not be validated with error: %+v.", request.GetClient().GetID(), err).WithWrap(err)) + if err = token.Valid(optsValidHeader...); err != nil { + return errorsx.WithStack(fmtRequestObjectDecodeError(token, client, issuer, openid, err)) + } + + if algAny && token.SignatureAlgorithm == consts.JSONWebTokenAlgNone { + return errorsx.WithStack( + ErrInvalidRequestObject. + WithHintf("%s client provided a request object that has an invalid 'kid' or 'alg' header value.", hintRequestObjectPrefix(openid)). + WithDebugf("%s client with id '%s' was not explicitly registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'none' in the header.", hintRequestObjectPrefix(openid), client.GetID())) } claims := token.Claims @@ -191,7 +151,7 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co v any ) - for k, v = range claims { + for k, v = range claims.ToMapClaims() { switch k { case consts.FormParameterRequest, consts.FormParameterRequestURI: // The request and request_uri parameters MUST NOT be included in Request Objects. @@ -230,57 +190,20 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co } } - if !algNone { - issuer := f.Config.GetIDTokenIssuer(ctx) - - if len(issuer) == 0 { - return errorsx.WithStack(ErrServerError.WithHintf("%s request could not be processed due to an authorization server configuration issue.", hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object that was signed but the issuer for this authorization server is not known.", request.GetClient().GetID())) - } - - if v, ok = claims[consts.ClaimIssuer]; !ok { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectSignedAbsentClaim, request.GetClient().GetID(), consts.ClaimIssuer)) - } - - clientID := request.GetClient().GetID() - - if value, ok = v.(string); !ok { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectValueTypeNotString, request.GetClient().GetID(), consts.ClaimIssuer, v, clientID, v)) - } - - if value != clientID { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectValueMismatch, clientID, consts.ClaimIssuer, value, clientID)) - } - - if v, ok = claims[consts.ClaimAudience]; !ok { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf(debugRequestObjectSignedAbsentClaim, request.GetClient().GetID(), consts.ClaimAudience)) - } - - var valid bool - - switch t := v.(type) { - case string: - valid = t == issuer - case []string: - for _, value = range t { - if value == issuer { - valid = true - - break - } - } - case []any: - for _, x := range t { - if value, ok = x.(string); ok && value == issuer { - valid = true + if len(issuer) == 0 { + return errorsx.WithStack(ErrServerError.WithHintf("%s request could not be processed due to an authorization server configuration issue.", hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object that was signed but the issuer for this authorization server is not known.", request.GetClient().GetID())) + } - break - } - } - } + optsValidClaims := []jwt.ClaimValidationOption{ + jwt.ValidateTimeFunc(func() time.Time { + return time.Now().UTC() + }), + jwt.ValidateIssuer(client.GetID()), + jwt.ValidateAudienceAny(issuer), + } - if !valid { - return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectInvalidAuthorizationClaim, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' included a request object with a 'aud' claim with the values '%s' which is required match the issuer '%s'.", request.GetClient().GetID(), value, issuer)) - } + if err = claims.Valid(optsValidClaims...); err != nil { + return errorsx.WithStack(fmtRequestObjectDecodeError(token, client, issuer, openid, err)) } claimScope := RemoveEmpty(strings.Split(request.Form.Get(consts.FormParameterScope), " ")) @@ -595,3 +518,81 @@ func (f *Fosite) newAuthorizeRequest(ctx context.Context, r *http.Request, isPAR return request, nil } + +func fmtRequestObjectDecodeError(token *jwt.Token, client JARClient, issuer string, openid bool, inner error) (outer *RFC6749Error) { + outer = ErrInvalidRequestObject.WithWrap(inner).WithHintf("%s request object could not be decoded or validated.", hintRequestObjectPrefix(openid)) + + if errJWTValidation := new(jwt.ValidationError); errors.As(inner, &errJWTValidation) { + switch { + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyIDInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'kid' header value '%s' due to the client registration 'request_object_signing_key_id' value but the request object was signed with the 'kid' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningKeyID(), token.KeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderAlgorithmInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'alg' header value '%s' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectSigningAlg(), token.SignatureAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderTypeInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be signed with the 'typ' header value '%s' but the request object was signed with the 'typ' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionTypeInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'typ' header value '%s' but the request object was encrypted with the 'typ' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalidMismatch): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with a 'cty' header value and signed with a 'typ' value that match but the request object was encrypted with the 'cty' header value '%s' and signed with the 'typ' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), token.HeaderJWE[consts.JSONWebTokenHeaderContentType], token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'cty' header value '%s' but the request object was encrypted with the 'cty' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderContentType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'kid' header value '%s' due to the client registration 'request_object_encryption_key_id' value but the request object was encrypted with the 'kid' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionKeyID(), token.EncryptionKeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyAlgorithmInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'alg' header value '%s' due to the client registration 'request_object_encryption_alg' value but the request object was encrypted with the 'alg' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionAlg(), token.KeyAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): + return outer.WithDebugf("%s client with id '%s' expects request objects to be encrypted with the 'enc' header value '%s' due to the client registration 'request_object_encryption_enc' value but the request object was encrypted with the 'enc' header value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetRequestObjectEncryptionEnc(), token.ContentEncryption) + case errJWTValidation.Has(jwt.ValidationErrorMalformedNotCompactSerialized): + return outer.WithDebugf("%s client with id '%s' provided a request object that was malformed. The request object does not appear to be a JWE or JWS compact serialized JWT.", hintRequestObjectPrefix(openid), client.GetID()) + case errJWTValidation.Has(jwt.ValidationErrorMalformed): + return outer.WithDebugf("%s client with id '%s' provided a request object that was malformed. %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): + return outer.WithDebugf("%s client with id '%s' provided a request object that was not able to be verified. %s.", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorSignatureInvalid): + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid signature.", hintRequestObjectPrefix(openid), client.GetID()) + case errJWTValidation.Has(jwt.ValidationErrorExpired): + exp, err := token.Claims.GetExpirationTime() + if err == nil { + return outer.WithDebugf("%s client with id '%s' provided a request object that was expired. The request object expired at %d.", hintRequestObjectPrefix(openid), client.GetID(), exp.Int64()) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that was expired. The request object does not have an 'exp' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuedAt): + iat, err := token.Claims.GetIssuedAt() + if err == nil { + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object was issued at %d.", hintRequestObjectPrefix(openid), client.GetID(), iat.Int64()) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object does not have an 'iat' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorNotValidYet): + nbf, err := token.Claims.GetNotBefore() + if err == nil { + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object is not valid before %d.", hintRequestObjectPrefix(openid), client.GetID(), nbf.Int64()) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that was issued in the future. The request object does not have an 'nbf' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuer): + iss, err := token.Claims.GetIssuer() + if err == nil { + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid issuer. The request object was expected to have an 'iss' claim which matches the value '%s' but the 'iss' claim had the value '%s'.", hintRequestObjectPrefix(openid), client.GetID(), client.GetID(), iss) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid issuer. The request object does not have an 'iss' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorAudience): + aud, err := token.Claims.GetAudience() + if err == nil { + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid audience. The request object was expected to have an 'aud' claim which matches the issuer value of '%s' but the 'aud' claim had the values '%s'.", hintRequestObjectPrefix(openid), client.GetID(), issuer, strings.Join(aud, "', '")) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that has an invalid audience. The request object does not have an 'aud' claim or it has an invalid type.", hintRequestObjectPrefix(openid), client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorClaimsInvalid): + return outer.WithDebugf("%s client with id '%s' provided a request object that had one or more invalid claims. Error occurred trying to validate the request objects claims: %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + default: + return outer.WithDebugf("%s client with id '%s' provided a request object that could not be validated. Error occurred trying to validate the request object: %s", hintRequestObjectPrefix(openid), client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + } + } else if errJWKLookup := new(jwt.JWKLookupError); errors.As(inner, &errJWKLookup) { + return outer.WithDebugf("%s client with id '%s' provided a request object that could not be validated due to a key lookup error. %s.", hintRequestObjectPrefix(openid), client.GetID(), errJWKLookup.Description) + } else { + return outer.WithDebugf("%s client with id '%s' provided a request object that could not be validated. %s.", hintRequestObjectPrefix(openid), client.GetID(), ErrorToDebugRFC6749Error(inner).Error()) + } +} diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index 7d957d1b..32920a4a 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -5,15 +5,20 @@ package oauth2 import ( "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/rsa" + "encoding/base64" "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" "regexp" + "strings" "testing" + "time" "github.com/go-jose/go-jose/v4" "github.com/stretchr/testify/assert" @@ -24,33 +29,129 @@ import ( ) func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { - key, err := rsa.GenerateKey(rand.Reader, 1024) //nolint:gosec + keyRSA, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) - jwks := &jose.JSONWebKeySet{ + keyECDSA, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + jwkNone := &jose.JSONWebKey{ + Key: jwt.UnsafeAllowNoneSignatureType, + } + + rawClientSecret := "aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd" + + clientSecretHS256 := NewPlainTextClientSecret(rawClientSecret) + + jwkEncAES256, err := jwt.NewClientSecretJWK(context.TODO(), []byte(rawClientSecret), "", string(jose.A256GCMKW), "", consts.JSONWebTokenUseEncryption) + require.NoError(t, err) + + jwkSigHS := &jose.JSONWebKey{ + Key: []byte(rawClientSecret), + KeyID: "hs256-sig", + Algorithm: string(jose.HS256), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPublicSigRSA := &jose.JSONWebKey{ + Key: keyRSA.Public(), + KeyID: "rs256-sig", + Algorithm: string(jose.RS256), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPrivateSigRSA := &jose.JSONWebKey{ + Key: keyRSA, + KeyID: "rs256-sig", + Algorithm: string(jose.RS256), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPublicSigRSA384 := &jose.JSONWebKey{ + Key: keyRSA.Public(), + KeyID: "rs384-sig", + Algorithm: string(jose.RS384), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPrivateSigRSA384 := &jose.JSONWebKey{ + Key: keyRSA, + KeyID: "rs384-sig", + Algorithm: string(jose.RS384), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPublicSigECDSA := &jose.JSONWebKey{ + Key: keyECDSA.Public(), + KeyID: "es256-sig", + Algorithm: string(jose.ES256), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPrivateSigECDSA := &jose.JSONWebKey{ + Key: keyECDSA, + KeyID: "es256-sig", + Algorithm: string(jose.ES256), + Use: consts.JSONWebTokenUseSignature, + } + + jwkPublicEncECDSA := &jose.JSONWebKey{ + Key: keyECDSA.Public(), + KeyID: "es256-enc", + Algorithm: string(jose.ECDH_ES_A128KW), + Use: consts.JSONWebTokenUseEncryption, + } + + jwkPrivateEncECDSA := &jose.JSONWebKey{ + Key: keyECDSA, + KeyID: "es256-enc", + Algorithm: string(jose.ECDH_ES_A128KW), + Use: consts.JSONWebTokenUseEncryption, + } + + jwksPrivate := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ - { - KeyID: "kid-foo", - Use: "sig", - Key: &key.PublicKey, - }, + *jwkPrivateSigRSA, + *jwkPrivateSigECDSA, + *jwkPrivateEncECDSA, }, } - assertionRequestObjectValid := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidRequestInRequest := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequest: "abc", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidRequestURIInRequest := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequestURI: "https://auth.example.com", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidClientIDValue := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: 100, consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidResponseTypeValue := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: 100, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidAudience := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.not-example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectInvalidIssuer := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "not-foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, key, "kid-foo") - assertionRequestObjectValidWithoutKID := mustGenerateAssertion(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz"}, key, "") - assertionRequestObjectValidNone := mustGenerateNoneAssertion(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state"}) + jwksPublic := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + *jwkPublicSigRSA, + *jwkPublicSigRSA384, + *jwkPublicSigECDSA, + *jwkPublicEncECDSA, + }, + } + + assertionRequestObjectValid := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidExpired := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimExpirationTime: time.Now().Add(-time.Hour).UTC().Unix(), consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidFuture := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimIssuedAt: time.Now().Add(time.Hour).UTC().Unix(), consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidNotValidYet := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimNotBefore: time.Now().Add(time.Hour).UTC().Unix(), consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidSignature := mangleSig(assertionRequestObjectValid) + assertionRequestObjectInvalidKID := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA384) + assertionRequestObjectInvalidTyp := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, &jwt.Headers{Extra: map[string]any{consts.JSONWebTokenHeaderType: "abc"}}, jwkPrivateSigRSA) + assertionRequestObjectInvalidJWEContentType := mustGenerateRequestObjectJWE(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwt.Headers{Extra: map[string]any{consts.JSONWebTokenHeaderContentType: "at+jwt"}}, jwkPrivateSigRSA, jwkEncAES256, jose.A256GCM) + assertionRequestObjectInvalidJWEType := mustGenerateRequestObjectJWE(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, &jwt.Headers{Extra: map[string]any{consts.JSONWebTokenHeaderType: "at+jwt"}}, jwkPrivateSigRSA, jwkEncAES256, jose.A256GCM) + assertionRequestObjectValidJWE := mustGenerateRequestObjectJWE(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, nil, jwkPrivateSigRSA, jwkEncAES256, jose.A256GCM) + assertionRequestObjectValidAssymetricJWE := mustGenerateRequestObjectJWE(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, nil, jwkPrivateSigECDSA, jwkPublicEncECDSA, jose.A128GCM) + assertionRequestObjectEmptyHS256 := mustGenerateRequestObjectJWS(t, jwt.MapClaims{}, nil, jwkSigHS) + assertionRequestObjectInvalidRequestInRequest := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequest: "abc", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidRequestURIInRequest := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.FormParameterRequestURI: "https://auth.example.com", consts.ClaimClientIdentifier: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidClientIDValue := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimClientIdentifier: 100, consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeImplicitFlowToken, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidResponseTypeValue := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: 100, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidAudience := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.not-example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectInvalidIssuer := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "not-foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterResponseType: consts.ResponseTypeAuthorizationCodeFlow, consts.FormParameterResponseMode: consts.ResponseModeFormPost}, nil, jwkPrivateSigRSA) + assertionRequestObjectValidWithoutKID := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}, consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz"}, nil, &jose.JSONWebKey{Key: keyRSA, Algorithm: string(jose.RS256), Use: consts.JSONWebTokenUseSignature}) + assertionRequestObjectValidNone := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state", consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}}, nil, jwkNone) + assertionRequestObjectValidHS256 := mustGenerateRequestObjectJWS(t, jwt.MapClaims{consts.FormParameterScope: "foo", "foo": "bar", "baz": "baz", consts.FormParameterState: "some-state", consts.ClaimIssuer: "foo", consts.ClaimAudience: []string{"https://auth.example.com"}}, nil, jwkSigHS) mux := http.NewServeMux() var handlerJWKS http.HandlerFunc = func(rw http.ResponseWriter, r *http.Request) { - require.NoError(t, json.NewEncoder(rw).Encode(jwks)) + require.NoError(t, json.NewEncoder(rw).Encode(jwksPublic)) } handleString := func(in string) http.HandlerFunc { @@ -109,9 +210,21 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldPassRequest", have: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValid}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, }, + { + name: "ShouldPassRequestJWE", + have: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidJWE}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterRequest: {assertionRequestObjectValidJWE}, "foo": {"bar"}, "baz": {"baz"}}, + }, + { + name: "ShouldPassRequestJWESymmetric", + have: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidAssymetricJWE}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "ES256", RequestObjectSigningKeyID: "es256-sig", RequestObjectEncryptionKeyID: "es256-enc", DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidAssymetricJWE}, "baz": {"baz"}, "foo": {"bar"}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}}, + }, { name: "ShouldFailRequestNotOpenIDConnectClient", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterRequest: {"foo"}}, @@ -183,31 +296,119 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailInvalidTokenMalformed", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {"bad-token"}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'foo' provided a request object which failed to validate with error: go-jose/go-jose: compact JWS format must have three parts.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' provided a request object that was malformed. The request object does not appear to be a JWE or JWS compact serialized JWT.", }, { name: "ShouldFailUnknownKID", - have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {mustGenerateAssertion(t, jwt.MapClaims{}, key, "does-not-exists")}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test"}}, + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {mustGenerateAssertion(t, jwt.MapClaims{}, keyRSA, "does-not-exists")}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test"}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' provided a request object that was not able to be verified. Error occurred retrieving the JSON Web Key. The JSON Web Token uses signing key with kid 'does-not-exists' which was not found.", + }, + { + name: "ShouldFailBadKID", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidKID}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", RequestObjectSigningKeyID: "rs256-sig", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'kid' header value 'rs256-sig' due to the client registration 'request_object_signing_key_id' value but the request object was signed with the 'kid' header value 'rs384-sig'.", + }, + { + name: "ShouldFailBadType", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidTyp}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", RequestObjectSigningKeyID: "rs256-sig", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'typ' header value 'JWT' but the request object was signed with the 'typ' header value 'abc'.", + }, + { + name: "ShouldFailJWEBadContentType", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidJWEContentType}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", RequestObjectSigningKeyID: "rs256-sig", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be encrypted with a 'cty' header value and signed with a 'typ' value that match but the request object was encrypted with the 'cty' header value 'at+jwt' and signed with the 'typ' header value 'JWT'.", + }, + { + name: "ShouldFailJWEBadType", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidJWEType}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", RequestObjectSigningKeyID: "rs256-sig", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be encrypted with the 'typ' header value 'JWT' but the request object was encrypted with the 'typ' header value 'at+jwt'.", + }, + { + name: "ShouldFailJWEBadKeyID", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidAssymetricJWE}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "ES256", RequestObjectSigningKeyID: "es256-sig", RequestObjectEncryptionKeyID: "abc", DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be encrypted with the 'kid' header value 'abc' due to the client registration 'request_object_encryption_key_id' value but the request object was encrypted with the 'kid' header value 'es256-enc'.", + }, + { + name: "ShouldFailJWEBadAlg", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidAssymetricJWE}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "ES256", RequestObjectSigningKeyID: "es256-sig", RequestObjectEncryptionAlg: "abc", DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be encrypted with the 'alg' header value 'abc' due to the client registration 'request_object_encryption_alg' value but the request object was encrypted with the 'alg' header value 'ECDH-ES+A128KW'.", + }, + { + name: "ShouldFailJWEBadEnc", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectValidAssymetricJWE}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "ES256", RequestObjectSigningKeyID: "es256-sig", RequestObjectEncryptionEnc: "abc", DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be encrypted with the 'enc' header value 'abc' due to the client registration 'request_object_encryption_enc' value but the request object was encrypted with the 'enc' header value 'A128GCM'.", + }, + { + name: "ShouldFailExpired", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectInvalidExpired}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errRegex: regexp.MustCompile(`^The request parameter contains an invalid Request Object\. OpenID Connect 1\.0 request object could not be decoded or validated\. OpenID Connect 1\.0 client with id 'foo' provided a request object that was expired\. The request object expired at \d+\.`), + }, + { + name: "ShouldFailFuture", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectInvalidFuture}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errRegex: regexp.MustCompile(`^The request parameter contains an invalid Request Object. OpenID Connect 1\.0 request object could not be decoded or validated\. OpenID Connect 1\.0 client with id 'foo' provided a request object that was issued in the future\. The request object was issued at \d+\.$`), + }, + { + name: "ShouldFailNotBefore", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequest: {assertionRequestObjectInvalidNotValidYet}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, + err: ErrInvalidRequestObject, + errRegex: regexp.MustCompile(`^The request parameter contains an invalid Request Object\. OpenID Connect 1\.0 request object could not be decoded or validated\. OpenID Connect 1\.0 client with id 'foo' provided a request object that was issued in the future\. The request object is not valid before \d+\.`), + }, + { + name: "ShouldFailBadSignature", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectInvalidSignature}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", RequestObjectSigningKeyID: "rs256-sig", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. Unable to retrieve RSA signing key from OAuth 2.0 Client. The JSON Web Token uses signing key with kid 'does-not-exists', which could not be found. The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed. The JSON Web Token uses signing key with kid 'does-not-exists', which could not be found.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' provided a request object that has an invalid signature.", }, { name: "ShouldFailBadAlgRS256", - have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {mustGenerateHSAssertion(t, jwt.MapClaims{})}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test"}}, + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectEmptyHS256}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "test", ClientSecret: clientSecretHS256}}, expected: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'test' expects request objects to be signed with the 'RS256' algorithm but the request object was signed with the 'HS256' algorithm.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'test' expects request objects to be signed with the 'alg' header value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value 'HS256'.", }, { name: "ShouldFailMismatchedClientID", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"not-foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectValid}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'client_id' claim with a value of 'foo' which is required to match the value 'not-foo' in the parameter with the same name from the OAuth 2.0 request syntax.", @@ -215,7 +416,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailRequestClientIDAssert", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"not-foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectInvalidClientIDValue}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectInvalidClientIDValue}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'client_id' claim with a value of '100' which is required to match the value 'not-foo' in the parameter with the same name from the OAuth 2.0 request syntax but instead of a string it had the int64 type.", @@ -223,7 +424,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailRequestWithRequest", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectInvalidRequestInRequest}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectInvalidRequestInRequest}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object which contained the 'request' or 'request_uri' claims but this is not permitted.", @@ -231,7 +432,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailRequestWithRequestURI", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectInvalidRequestURIInRequest}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectInvalidRequestURIInRequest}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object which contained the 'request' or 'request_uri' claims but this is not permitted.", @@ -239,7 +440,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailMismatchedResponseType", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectValid}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'response_type' claim with a value of 'token' which is required to match the value 'code' in the parameter with the same name from the OAuth 2.0 request syntax.", @@ -247,7 +448,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldFailMismatchedResponseTypeAsserted", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterResponseMode: {consts.ResponseModeNone}, consts.FormParameterRequest: {assertionRequestObjectInvalidResponseTypeValue}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterResponseMode: {consts.ResponseModeFormPost}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectInvalidResponseTypeValue}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'response_type' claim with a value of '100' which is required to match the value 'code' in the parameter with the same name from the OAuth 2.0 request syntax but instead of a string it had the int64 type.", @@ -255,13 +456,13 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) { name: "ShouldPassWithoutKID", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidWithoutKID}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidWithoutKID}, "foo": {"bar"}, "baz": {"baz"}}, }, { name: "ShouldFailRequestURINotWhiteListed", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "standard.jwk").String()}}, - client: &DefaultJARClient{JSONWebKeys: jwks, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, + client: &DefaultJARClient{JSONWebKeys: jwksPublic, RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterScope: {"foo openid"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidWithoutKID}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestURI, errRegex: regexp.MustCompile(`^The request_uri in the authorization request returns an error or contains invalid data\. OpenID Connect 1\.0 request failed to fetch request parameters from the provided 'request_uri'\. The OAuth 2\.0 client with id 'foo' provided the 'request_uri' parameter with value 'http://127.0.0.1:\d+/request-object/valid/standard\.jwk' which is not whitelisted.$`), @@ -278,7 +479,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'foo' expects request objects to be signed with the 'RS256' algorithm but the request object was signed with the 'none' algorithm.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' header value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value 'none'.", }, { name: "ShouldFailRequestURIAlgNone", @@ -286,7 +487,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", RequestURIs: []string{root.JoinPath("request-object", "valid", "none.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'foo' expects request objects to be signed with the 'RS256' algorithm but the request object was signed with the 'none' algorithm.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' header value 'RS256' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value 'none'.", }, { name: "ShouldFailRequestRS256", @@ -294,7 +495,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValid}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'foo' expects request objects to be signed with the 'none' algorithm but the request object was signed with the 'RS256' algorithm.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' header value 'none' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value 'RS256'.", }, { name: "ShouldFailRequestURIRS256", @@ -302,12 +503,12 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, RequestURIs: []string{root.JoinPath("request-object", "valid", "standard.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "standard.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request failed with an error attempting to validate the request object. The OAuth 2.0 client with id 'foo' expects request objects to be signed with the 'none' algorithm but the request object was signed with the 'RS256' algorithm.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' expects request objects to be signed with the 'alg' header value 'none' due to the client registration 'request_object_signing_alg' value but the request object was signed with the 'alg' header value 'RS256'.", }, { name: "ShouldPassRequestAlgNone", have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidNone}}, - client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone}, + client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: consts.JSONWebTokenAlgNone, DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, }, { @@ -317,16 +518,26 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, }, { - name: "ShouldPassRequestAlgNoneAllowAny", - have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidNone}}, - client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String()}, - expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, + name: "ShouldPassRequestAlgHS256", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidHS256}}, + client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: string(jose.HS256), DefaultClient: &DefaultClient{ID: "foo", ClientSecret: clientSecretHS256}}, + expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidHS256}, "foo": {"bar"}, "baz": {"baz"}}, }, { - name: "ShouldPassRequestURIAlgNoneAllowAny", - have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}}, - client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "", RequestURIs: []string{root.JoinPath("request-object", "valid", "none.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, - expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, + name: "ShouldPassRequestAlgNoneAllowAny", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterRequest: {assertionRequestObjectValidNone}}, + client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), DefaultClient: &DefaultClient{ID: "foo"}}, + expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was not explicitly registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'none' in the header.", + }, + { + name: "ShouldPassRequestURIAlgNoneAllowAny", + have: url.Values{consts.FormParameterScope: {consts.ScopeOpenID}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeImplicitFlowToken}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}}, + client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "", RequestURIs: []string{root.JoinPath("request-object", "valid", "none.jwk").String()}, DefaultClient: &DefaultClient{ID: "foo"}}, + expected: url.Values{consts.FormParameterResponseType: {"token"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterState: {"some-state"}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequestURI: {root.JoinPath("request-object", "valid", "none.jwk").String()}, "foo": {"bar"}, "baz": {"baz"}}, + err: ErrInvalidRequestObject, + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 client provided a request object that has an invalid 'kid' or 'alg' header value. OpenID Connect 1.0 client with id 'foo' was not explicitly registered with a 'request_object_signing_alg' value of 'none' but the request object had the 'alg' value 'none' in the header.", }, { name: "ShouldFailRequestBadAudience", @@ -334,7 +545,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'aud' claim with the values 'https://auth.not-example.com' which is required match the issuer 'https://auth.example.com'.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' provided a request object that has an invalid audience. The request object was expected to have an 'aud' claim which matches the issuer value of 'https://auth.example.com' but the 'aud' claim had the values 'https://auth.not-example.com'.", }, { name: "ShouldFailRequestURIBadAudience", @@ -350,7 +561,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) client: &DefaultJARClient{JSONWebKeysURI: root.JoinPath("jwks.json").String(), RequestObjectSigningAlg: "RS256", DefaultClient: &DefaultClient{ID: "foo"}}, expected: url.Values{consts.FormParameterState: {"some-state"}, consts.FormParameterClientID: {"foo"}, consts.FormParameterResponseType: {consts.ResponseTypeAuthorizationCodeFlow}, consts.FormParameterScope: {"foo openid"}, consts.FormParameterRequest: {assertionRequestObjectValidNone}, "foo": {"bar"}, "baz": {"baz"}}, err: ErrInvalidRequestObject, - errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request included a request object which excluded claims that are required or included claims that did not match the OAuth 2.0 request syntax or are generally not permitted. The OAuth 2.0 client with id 'foo' included a request object with a 'iss' claim with a value of 'not-foo' which is required to match the value 'foo' in the parameter with the same name from the OAuth 2.0 request syntax.", + errString: "The request parameter contains an invalid Request Object. OpenID Connect 1.0 request object could not be decoded or validated. OpenID Connect 1.0 client with id 'foo' provided a request object that has an invalid issuer. The request object was expected to have an 'iss' claim which matches the value 'foo' but the 'iss' claim had the value 'not-foo'.", }, { name: "ShouldFailRequestURIBadIssuer", @@ -371,7 +582,14 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T) }, } - provider := &Fosite{Config: &Config{JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), IDTokenIssuer: "https://auth.example.com"}} + config := &Config{JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), IDTokenIssuer: "https://auth.example.com"} + + strategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerUnverifiedFromJWKS(jwksPrivate), + } + + provider := &Fosite{Config: &Config{JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), IDTokenIssuer: "https://auth.example.com", JWTStrategy: strategy}} err = provider.authorizeRequestParametersFromOpenIDConnectRequestObject(context.Background(), r, tc.par) if tc.err != nil { @@ -400,21 +618,42 @@ func mustGenerateAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateK if kid != "" { token.Header[consts.JSONWebTokenHeaderKeyIdentifier] = kid } - tokenString, err := token.SignedString(key) + tokenString, err := token.CompactSignedString(key) require.NoError(t, err) return tokenString } func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims) string { token := jwt.NewWithClaims(jose.HS256, claims) - tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")) + tokenString, err := token.CompactSignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")) require.NoError(t, err) return tokenString } -func mustGenerateNoneAssertion(t *testing.T, claims jwt.MapClaims) string { - token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) - tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) +func mangleSig(tokenString string) string { + parts := strings.Split(tokenString, ".") + raw, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + panic(err) + } + + raw = append(raw, []byte("abc")...) + + parts[2] = base64.RawURLEncoding.EncodeToString(raw) + + return strings.Join(parts, ".") +} + +func mustGenerateRequestObjectJWS(t *testing.T, claims jwt.MapClaims, headers jwt.Mapper, key *jose.JSONWebKey) string { + token, _, err := jwt.EncodeCompactSigned(context.TODO(), claims, headers, key) require.NoError(t, err) - return tokenString + + return token +} + +func mustGenerateRequestObjectJWE(t *testing.T, claims jwt.MapClaims, headers, headersJWE jwt.Mapper, key *jose.JSONWebKey, keyEnc *jose.JSONWebKey, enc jose.ContentEncryption) string { + token, _, err := jwt.EncodeNestedCompactEncrypted(context.TODO(), claims, headers, headersJWE, key, keyEnc, enc) + require.NoError(t, err) + + return token } diff --git a/authorize_response_writer_test.go b/authorize_response_writer_test.go index 7ba3d441..d6d33973 100644 --- a/authorize_response_writer_test.go +++ b/authorize_response_writer_test.go @@ -91,6 +91,5 @@ func TestNewAuthorizeResponse(t *testing.T) { } else { assert.NotNil(t, responder, "%d", k) } - t.Logf("Passed test case %d", k) } } diff --git a/client.go b/client.go index 2dc370f2..3a753867 100644 --- a/client.go +++ b/client.go @@ -17,8 +17,21 @@ type Client interface { // GetID returns the client ID. GetID() (id string) + // GetClientSecret returns the ClientSecret. GetClientSecret() (secret ClientSecret) + // GetClientSecretPlainText returns the ClientSecret as plaintext if available. The semantics of this function + // return values are important. + // If the client is not configured with a secret the return should be: + // - secret with value nil, ok with value false, and err with value of nil + // If the client is configured with a secret but is hashed or otherwise not a plaintext value: + // - secret with value nil, ok with value true, and err with value of nil + // If an error occurs retrieving the secret other than this: + // - secret with value nil, ok with value true, and err with value of the error + // If the plaintext secret is successful: + // - secret with value of the bytes of the plaintext secret, ok with value true, and err with value of nil + GetClientSecretPlainText() (secret []byte, ok bool, err error) + // GetRedirectURIs returns the client's allowed redirect URIs. GetRedirectURIs() []string @@ -230,6 +243,19 @@ type AuthenticationMethodClient interface { // methods. GetRevocationEndpointAuthSigningAlg() (alg string) + // GetPushedAuthorizationRequestEndpointAuthMethod is equivalent to the + // 'pushed_authorize_request_endpoint_auth_method' client metadata value which determines the requested Client + // Authentication method for the Pushed Authorization Request Endpoint. The options are client_secret_post, + // client_secret_basic, client_secret_jwt, and private_key_jwt. + GetPushedAuthorizationRequestEndpointAuthMethod() (method string) + + // GetPushedAuthorizationRequestEndpointAuthSigningAlg is equivalent to the + // 'pushed_authorization_request_endpoint_auth_signing_alg' client metadata value which determines the JWS [JWS] alg + // algorithm [JWA] that MUST be used for signing the JWT [JWT] used to authenticate the + // Client at the Pushed Authorization Request Endpoint for the private_key_jwt and client_secret_jwt authentication + // methods. + GetPushedAuthorizationRequestEndpointAuthSigningAlg() (alg string) + JSONWebKeysClient } @@ -368,7 +394,7 @@ type RequestedAudienceImplicitClient interface { type IntrospectionJWTResponseClient interface { // GetIntrospectionSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements for // OAuth 2.0 JWT introspection response specifications. If unspecified the other available parameters will be - // // utilized to select an appropriate key. + // utilized to select an appropriate key. GetIntrospectionSignedResponseKeyID() (kid string) // GetIntrospectionSignedResponseAlg is equivalent to the 'introspection_signed_response_alg' client metadata @@ -379,7 +405,7 @@ type IntrospectionJWTResponseClient interface { // GetIntrospectionEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements for // OAuth 2.0 JWT introspection response specifications. If unspecified the other available parameters will be - // // utilized to select an appropriate key. + // utilized to select an appropriate key. GetIntrospectionEncryptedResponseKeyID() (kid string) // GetIntrospectionEncryptedResponseAlg is equivalent to the 'introspection_encrypted_response_alg' client metadata @@ -414,20 +440,22 @@ type DefaultClient struct { type DefaultJARClient struct { *DefaultClient - JSONWebKeysURI string `json:"jwks_uri"` - JSONWebKeys *jose.JSONWebKeySet `json:"jwks"` - TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` - IntrospectionEndpointAuthMethod string `json:"introspection_endpoint_auth_method"` - RevocationEndpointAuthMethod string `json:"revocation_endpoint_auth_method"` - RequestURIs []string `json:"request_uris"` - RequestObjectSigningKeyID string `json:"request_object_signing_kid"` - RequestObjectSigningAlg string `json:"request_object_signing_alg"` - RequestObjectEncryptionKeyID string `json:"request_object_encryption_kid"` - RequestObjectEncryptionAlg string `json:"request_object_encryption_alg"` - RequestObjectEncryptionEnc string `json:"request_object_encryption_enc"` - TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` - IntrospectionEndpointAuthSigningAlg string `json:"introspection_endpoint_auth_signing_alg"` - RevocationEndpointAuthSigningAlg string `json:"revocation_endpoint_auth_signing_alg"` + JSONWebKeysURI string `json:"jwks_uri"` + JSONWebKeys *jose.JSONWebKeySet `json:"jwks"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + IntrospectionEndpointAuthMethod string `json:"introspection_endpoint_auth_method"` + RevocationEndpointAuthMethod string `json:"revocation_endpoint_auth_method"` + PushedAuthorizationRequestEndpointAuthMethod string `json:"pushed_authorization_request_endpoint_auth_method"` + RequestURIs []string `json:"request_uris"` + RequestObjectSigningKeyID string `json:"request_object_signing_kid"` + RequestObjectSigningAlg string `json:"request_object_signing_alg"` + RequestObjectEncryptionKeyID string `json:"request_object_encryption_kid"` + RequestObjectEncryptionAlg string `json:"request_object_encryption_alg"` + RequestObjectEncryptionEnc string `json:"request_object_encryption_enc"` + TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` + IntrospectionEndpointAuthSigningAlg string `json:"introspection_endpoint_auth_signing_alg"` + RevocationEndpointAuthSigningAlg string `json:"revocation_endpoint_auth_signing_alg"` + PushedAuthorizationRequestEndpointAuthSigningAlg string `json:"pushed_authorization_request_endpoint_auth_signing_alg"` } type DefaultResponseModeClient struct { @@ -455,6 +483,22 @@ func (c *DefaultClient) GetClientSecret() (secret ClientSecret) { return c.ClientSecret } +func (c *DefaultClient) GetClientSecretPlainText() (secret []byte, ok bool, err error) { + if c.ClientSecret == nil || !c.ClientSecret.Valid() { + return nil, false, nil + } + + if !c.ClientSecret.IsPlainText() { + return nil, true, nil + } + + if secret, err = c.ClientSecret.GetPlainTextValue(); err != nil { + return nil, true, err + } + + return secret, true, nil +} + func (c *DefaultClient) GetRotatedClientSecrets() (secrets []ClientSecret) { return c.RotatedClientSecrets } @@ -513,6 +557,10 @@ func (c *DefaultJARClient) GetRevocationEndpointAuthSigningAlg() string { return c.RevocationEndpointAuthSigningAlg } +func (c *DefaultJARClient) GetPushedAuthorizationRequestEndpointAuthSigningAlg() (alg string) { + return c.PushedAuthorizationRequestEndpointAuthSigningAlg +} + func (c *DefaultJARClient) GetRequestObjectSigningKeyID() string { return c.RequestObjectSigningKeyID } @@ -545,6 +593,10 @@ func (c *DefaultJARClient) GetRevocationEndpointAuthMethod() string { return c.RevocationEndpointAuthMethod } +func (c *DefaultJARClient) GetPushedAuthorizationRequestEndpointAuthMethod() string { + return c.PushedAuthorizationRequestEndpointAuthMethod +} + func (c *DefaultJARClient) GetRequestURIs() []string { return c.RequestURIs } diff --git a/client_authentication.go b/client_authentication.go index 6bb6cfd0..d40aebef 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -6,19 +6,15 @@ package oauth2 import ( "context" "crypto" - "crypto/ecdsa" - "crypto/rsa" "encoding/base64" "errors" "net/http" "net/url" - "sort" "strings" "github.com/go-jose/go-jose/v4" "authelia.com/provider/oauth2/internal/consts" - "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -47,34 +43,6 @@ func (f *Fosite) AuthenticateClientWithAuthHandler(ctx context.Context, r *http. return strategy.AuthenticateClient(ctx, r, form, handler) } -func (f *Fosite) findClientPublicJWK(ctx context.Context, client JARClient, t *jwt.Token, expectsRSAKey bool) (key any, err error) { - var ( - keys *jose.JSONWebKeySet - ) - - if keys = client.GetJSONWebKeys(); keys != nil { - return findPublicKey(t, keys, expectsRSAKey) - } - - if location := client.GetJSONWebKeysURI(); len(location) > 0 { - if keys, err = f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, false); err != nil { - return nil, err - } - - if key, err = findPublicKey(t, keys, expectsRSAKey); err == nil { - return key, nil - } - - if keys, err = f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, true); err != nil { - return nil, err - } - - return findPublicKey(t, keys, expectsRSAKey) - } - - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request.")) -} - // CompareClientSecret compares a raw secret input from a client to the registered client secret. If the secret is valid // it returns nil, otherwise it returns an error. The ErrClientSecretNotRegistered error indicates the ClientSecret // is nil, all other errors are returned directly from the ClientSecret.Compare function. @@ -111,35 +79,6 @@ func CompareClientSecret(ctx context.Context, client Client, rawSecret []byte) ( return err } -// FindClientPublicJWK takes a JARClient and a kid, alg, and use to resolve a Public JWK for the client. -func FindClientPublicJWK(ctx context.Context, provider JWKSFetcherStrategyProvider, client JSONWebKeysClient, kid, alg, use string) (key any, err error) { - if set := client.GetJSONWebKeys(); set != nil { - return findPublicKeyByKID(kid, alg, use, set) - } - - strategy := provider.GetJWKSFetcherStrategy(ctx) - - var keys *jose.JSONWebKeySet - - if location := client.GetJSONWebKeysURI(); len(location) > 0 { - if keys, err = strategy.Resolve(ctx, location, false); err != nil { - return nil, err - } - - if key, err = findPublicKeyByKID(kid, alg, use, keys); err == nil { - return key, nil - } - - if keys, err = strategy.Resolve(ctx, location, true); err != nil { - return nil, err - } - - return findPublicKeyByKID(kid, alg, use, keys) - } - - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request.")) -} - func getClientCredentialsSecretBasic(r *http.Request) (id, secret string, ok bool, err error) { auth := r.Header.Get(consts.HeaderAuthorization) @@ -188,119 +127,6 @@ func getClientCredentialsSecretBasic(r *http.Request) (id, secret string, ok boo return id, secret, secret != "", nil } -func getJWTHeaderKIDAlg(header map[string]any) (kid, alg string) { - kid, _ = header[consts.JSONWebTokenHeaderKeyIdentifier].(string) - alg, _ = header[consts.JSONWebTokenHeaderAlgorithm].(string) - - return kid, alg -} - -type partial struct { - points int - jwk jose.JSONWebKey -} - -func findPublicKeyByKID(kid, alg, use string, set *jose.JSONWebKeySet) (key any, err error) { - if len(set.Keys) == 0 { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The retrieved JSON Web Key Set does not contain any JSON Web Keys.")) - } - - partials := []partial{} - - for _, jwk := range set.Keys { - if jwk.Use == use && jwk.Algorithm == alg && jwk.KeyID == kid { - switch k := jwk.Key.(type) { - case PrivateKey: - return k.Public(), nil - default: - return k, nil - } - } - - p := partial{} - - if jwk.KeyID != kid { - if jwk.KeyID == "" { - p.points -= 3 - } else { - continue - } - } - - if jwk.Use != use { - if jwk.Use == "" { - p.points -= 2 - } else { - continue - } - } - - if jwk.Algorithm != alg && jwk.Algorithm != "" { - if jwk.Algorithm == "" { - p.points -= 1 - } else { - continue - } - } - - p.jwk = jwk - - partials = append(partials, p) - } - - if len(partials) != 0 { - sort.Slice(partials, func(i, j int) bool { - return partials[i].points > partials[j].points - }) - - switch k := partials[0].jwk.Key.(type) { - case PrivateKey: - return k.Public(), nil - default: - return k, nil - } - } - - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find JWK with kid value '%s', alg value '%s', and use value '%s' in the JSON Web Key Set.", kid, alg, use)) -} - -func findPublicKey(t *jwt.Token, set *jose.JSONWebKeySet, expectsRSAKey bool) (any, error) { - keys := set.Keys - if len(keys) == 0 { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The retrieved JSON Web Key Set does not contain any key.")) - } - - kid, ok := t.Header[consts.JSONWebTokenHeaderKeyIdentifier].(string) - if ok { - keys = set.Key(kid) - } - - if len(keys) == 0 { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The JSON Web Token uses signing key with kid '%s', which could not be found.", kid)) - } - - for _, key := range keys { - if key.Use != consts.JSONWebTokenUseSignature { - continue - } - if expectsRSAKey { - if k, ok := key.Key.(*rsa.PublicKey); ok { - return k, nil - } - } else { - if k, ok := key.Key.(*ecdsa.PublicKey); ok { - return k, nil - } - } - } - - if expectsRSAKey { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find RSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) - } else { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find ECDSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) - } -} - func getClientCredentialsClientAssertion(form url.Values) (assertion, assertionType string, hasAssertion bool) { assertionType, assertion = form.Get(consts.FormParameterClientAssertionType), form.Get(consts.FormParameterClientAssertion) @@ -335,9 +161,21 @@ type EndpointClientAuthHandler interface { // GetAuthMethod returns the appropriate auth method for this client. GetAuthMethod(client AuthenticationMethodClient) string + // GetAuthSigningKeyID returns the appropriate auth signature key id for this client. + GetAuthSigningKeyID(client AuthenticationMethodClient) string + // GetAuthSigningAlg returns the appropriate auth signature algorithm for this client. GetAuthSigningAlg(client AuthenticationMethodClient) string + // GetAuthEncryptionKeyID returns the appropriate auth encryption key id for this client. + GetAuthEncryptionKeyID(client AuthenticationMethodClient) string + + // GetAuthEncryptionAlg returns the appropriate auth encryption key algorithm for this client. + GetAuthEncryptionAlg(client AuthenticationMethodClient) string + + // GetAuthEncryptionEnc returns the appropriate auth encryption content encryption for this client. + GetAuthEncryptionEnc(client AuthenticationMethodClient) string + // Name returns the appropriate name for this endpoint for logging purposes. Name() string @@ -345,16 +183,77 @@ type EndpointClientAuthHandler interface { AllowAuthMethodAny() bool } +type EndpointClientAuthJWTClient struct { + client AuthenticationMethodClient + handler EndpointClientAuthHandler +} + +func (c *EndpointClientAuthJWTClient) GetID() string { + return c.client.GetID() +} + +func (c *EndpointClientAuthJWTClient) GetClientSecretPlainText() (secret []byte, ok bool, err error) { + return c.client.GetClientSecretPlainText() +} + +func (c *EndpointClientAuthJWTClient) GetJSONWebKeys() (jwks *jose.JSONWebKeySet) { + return c.client.GetJSONWebKeys() +} + +func (c *EndpointClientAuthJWTClient) GetJSONWebKeysURI() (uri string) { + return c.client.GetJSONWebKeysURI() +} + +func (c *EndpointClientAuthJWTClient) GetSigningKeyID() (kid string) { + return "" +} + +func (c *EndpointClientAuthJWTClient) GetSigningAlg() (alg string) { + return c.handler.GetAuthSigningAlg(c.client) +} + +func (c *EndpointClientAuthJWTClient) GetEncryptionKeyID() (kid string) { + return "" +} + +func (c *EndpointClientAuthJWTClient) GetEncryptionAlg() (alg string) { + return "" +} + +func (c *EndpointClientAuthJWTClient) GetEncryptionEnc() (enc string) { + return "" +} + +func (c *EndpointClientAuthJWTClient) IsClientSigned() (is bool) { + return true +} + type TokenEndpointClientAuthHandler struct{} func (h *TokenEndpointClientAuthHandler) GetAuthMethod(client AuthenticationMethodClient) string { return client.GetTokenEndpointAuthMethod() } +func (h *TokenEndpointClientAuthHandler) GetAuthSigningKeyID(client AuthenticationMethodClient) string { + return "" +} + func (h *TokenEndpointClientAuthHandler) GetAuthSigningAlg(client AuthenticationMethodClient) string { return client.GetTokenEndpointAuthSigningAlg() } +func (h *TokenEndpointClientAuthHandler) GetAuthEncryptionKeyID(client AuthenticationMethodClient) string { + return "" +} + +func (h *TokenEndpointClientAuthHandler) GetAuthEncryptionAlg(client AuthenticationMethodClient) string { + return "" +} + +func (h *TokenEndpointClientAuthHandler) GetAuthEncryptionEnc(client AuthenticationMethodClient) string { + return "" +} + func (h *TokenEndpointClientAuthHandler) Name() string { return "token" } @@ -369,10 +268,26 @@ func (h *IntrospectionEndpointClientAuthHandler) GetAuthMethod(client Authentica return client.GetIntrospectionEndpointAuthMethod() } +func (h *IntrospectionEndpointClientAuthHandler) GetAuthSigningKeyID(client AuthenticationMethodClient) string { + return "" +} + func (h *IntrospectionEndpointClientAuthHandler) GetAuthSigningAlg(client AuthenticationMethodClient) string { return client.GetIntrospectionEndpointAuthSigningAlg() } +func (h *IntrospectionEndpointClientAuthHandler) GetAuthEncryptionKeyID(client AuthenticationMethodClient) string { + return "" +} + +func (h *IntrospectionEndpointClientAuthHandler) GetAuthEncryptionAlg(client AuthenticationMethodClient) string { + return "" +} + +func (h *IntrospectionEndpointClientAuthHandler) GetAuthEncryptionEnc(client AuthenticationMethodClient) string { + return "" +} + func (h *IntrospectionEndpointClientAuthHandler) Name() string { return "introspection" } @@ -387,10 +302,26 @@ func (h *RevocationEndpointClientAuthHandler) GetAuthMethod(client Authenticatio return client.GetRevocationEndpointAuthMethod() } +func (h *RevocationEndpointClientAuthHandler) GetAuthSigningKeyID(client AuthenticationMethodClient) string { + return "" +} + func (h *RevocationEndpointClientAuthHandler) GetAuthSigningAlg(client AuthenticationMethodClient) string { return client.GetRevocationEndpointAuthSigningAlg() } +func (h *RevocationEndpointClientAuthHandler) GetAuthEncryptionKeyID(client AuthenticationMethodClient) string { + return "" +} + +func (h *RevocationEndpointClientAuthHandler) GetAuthEncryptionAlg(client AuthenticationMethodClient) string { + return "" +} + +func (h *RevocationEndpointClientAuthHandler) GetAuthEncryptionEnc(client AuthenticationMethodClient) string { + return "" +} + func (h *RevocationEndpointClientAuthHandler) Name() string { return "revocation" } diff --git a/client_authentication_jwks_strategy.go b/client_authentication_jwks_strategy.go index 2d12fedc..fd882a0e 100644 --- a/client_authentication_jwks_strategy.go +++ b/client_authentication_jwks_strategy.go @@ -13,21 +13,13 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/hashicorp/go-retryablehttp" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) const defaultJWKSFetcherStrategyCachePrefix = "authelia.com/provider/oauth2.DefaultJWKSFetcherStrategy:" -// JWKSFetcherStrategy is a strategy which pulls (optionally caches) JSON Web Key Sets from a location, -// typically a client's jwks_uri. -type JWKSFetcherStrategy interface { - // Resolve returns the JSON Web Key Set, or an error if something went wrong. The forceRefresh, if true, forces - // the strategy to fetch the key from the remote. If forceRefresh is false, the strategy may use a caching strategy - // to fetch the key. - Resolve(ctx context.Context, location string, ignoreCache bool) (*jose.JSONWebKeySet, error) -} - -// DefaultJWKSFetcherStrategy is a default implementation of the JWKSFetcherStrategy interface. +// DefaultJWKSFetcherStrategy is a default implementation of the jwt.JWKSFetcherStrategy interface. type DefaultJWKSFetcherStrategy struct { client *retryablehttp.Client cache *ristretto.Cache @@ -36,7 +28,7 @@ type DefaultJWKSFetcherStrategy struct { } // NewDefaultJWKSFetcherStrategy returns a new instance of the DefaultJWKSFetcherStrategy. -func NewDefaultJWKSFetcherStrategy(opts ...func(*DefaultJWKSFetcherStrategy)) JWKSFetcherStrategy { +func NewDefaultJWKSFetcherStrategy(opts ...func(*DefaultJWKSFetcherStrategy)) jwt.JWKSFetcherStrategy { dc, err := ristretto.NewCache(&ristretto.Config{ NumCounters: 10000 * 10, MaxCost: 10000, @@ -100,7 +92,7 @@ func (s *DefaultJWKSFetcherStrategy) Resolve(ctx context.Context, location strin if !ok || ignoreCache { req, err := retryablehttp.NewRequest(http.MethodGet, location, nil) if err != nil { - return nil, errorsx.WithStack(ErrServerError.WithHintf("Unable to create HTTP 'GET' request to fetch JSON Web Keys from location '%s'.", location).WithWrap(err).WithDebugError(err)) + return nil, errorsx.WithStack(ErrServerError.WithHintf("Unable to create HTTP 'GET' request to fetch JSON Web Keys from location '%s'.", location).WithWrap(err).WithDebugError(err)) } hc := s.client diff --git a/client_authentication_secret.go b/client_authentication_secret.go index bcb33845..e84e4ceb 100644 --- a/client_authentication_secret.go +++ b/client_authentication_secret.go @@ -19,7 +19,8 @@ type ClientSecret interface { IsPlainText() (is bool) // GetPlainTextValue is a utility function to return the secret in the plaintext format making it usable for the - // client_secret_jwt authentication method. + // client_secret_jwt authentication method. If the client secret doesn't have a value that is plaintext it should + // return oauth2.ErrClientSecretNotPlainText for the sake of deterministic error values. GetPlainTextValue() (secret []byte, err error) // Valid should return false if the secret is nil or otherwise invalid. diff --git a/client_authentication_secret_plaintext.go b/client_authentication_secret_plaintext.go new file mode 100644 index 00000000..4d7ffc7e --- /dev/null +++ b/client_authentication_secret_plaintext.go @@ -0,0 +1,41 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oauth2 + +import ( + "context" + "crypto/subtle" + "fmt" + + "authelia.com/provider/oauth2/x/errorsx" +) + +// NewPlainTextClientSecret returns a new PlainTextClientSecret given a value. +func NewPlainTextClientSecret(value string) *PlainTextClientSecret { + return &PlainTextClientSecret{value: []byte(value)} +} + +type PlainTextClientSecret struct { + value []byte +} + +func (s *PlainTextClientSecret) IsPlainText() (is bool) { + return true +} + +func (s *PlainTextClientSecret) GetPlainTextValue() (secret []byte, err error) { + return s.value, nil +} + +func (s *PlainTextClientSecret) Compare(ctx context.Context, secret []byte) (err error) { + if subtle.ConstantTimeCompare(s.value, secret) == 0 { + return errorsx.WithStack(fmt.Errorf("secrets don't match")) + } + + return nil +} + +func (s *PlainTextClientSecret) Valid() (valid bool) { + return s != nil && len(s.value) != 0 +} diff --git a/client_authentication_strategy.go b/client_authentication_strategy.go index 9f63f7d3..0b39c58f 100644 --- a/client_authentication_strategy.go +++ b/client_authentication_strategy.go @@ -10,9 +10,8 @@ import ( "strings" "time" - xjwt "github.com/golang-jwt/jwt/v5" - "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -21,12 +20,13 @@ type DefaultClientAuthenticationStrategy struct { ClientManager } Config interface { + JWTStrategyProvider JWKSFetcherStrategyProvider AllowedJWTAssertionAudiencesProvider } } -func (s *DefaultClientAuthenticationStrategy) AuthenticateClient(ctx context.Context, r *http.Request, form url.Values, resolver EndpointClientAuthHandler) (client Client, method string, err error) { +func (s *DefaultClientAuthenticationStrategy) AuthenticateClient(ctx context.Context, r *http.Request, form url.Values, handler EndpointClientAuthHandler) (client Client, method string, err error) { var ( id, secret string @@ -48,7 +48,7 @@ func (s *DefaultClientAuthenticationStrategy) AuthenticateClient(ctx context.Con var assertion *ClientAssertion if hasAssertion { - if assertion, err = NewClientAssertion(ctx, s.Store, assertionValue, assertionType, resolver); err != nil { + if assertion, err = NewClientAssertion(ctx, s.Config.GetJWTStrategy(ctx), s.Store, assertionValue, assertionType, handler); err != nil { return nil, "", err } } @@ -64,10 +64,10 @@ func (s *DefaultClientAuthenticationStrategy) AuthenticateClient(ctx context.Con hasNone := !hasPost && !hasBasic && assertion == nil && len(id) != 0 - return s.authenticate(ctx, id, secret, assertion, hasBasic, hasPost, hasNone, resolver) + return s.authenticate(ctx, id, secret, assertion, hasBasic, hasPost, hasNone, handler) } -func (s *DefaultClientAuthenticationStrategy) authenticate(ctx context.Context, id, secret string, assertion *ClientAssertion, hasBasic, hasPost, hasNone bool, resolver EndpointClientAuthHandler) (client Client, method string, err error) { +func (s *DefaultClientAuthenticationStrategy) authenticate(ctx context.Context, id, secret string, assertion *ClientAssertion, hasBasic, hasPost, hasNone bool, handler EndpointClientAuthHandler) (client Client, method string, err error) { var methods []string if hasBasic { @@ -115,16 +115,16 @@ func (s *DefaultClientAuthenticationStrategy) authenticate(ctx context.Context, return nil, "", errorsx.WithStack(ErrInvalidRequest. WithHintf("Client Authentication failed with more than one known authentication method included in the request which is not permitted."). - WithDebugf("The registered client with id '%s' and the authorization server policy does not permit this malformed request. The `%s_endpoint_auth_method` methods determined to be used were '%s'.", client.GetID(), resolver.Name(), strings.Join(methods, "', '"))) + WithDebugf("The registered client with id '%s' and the authorization server policy does not permit this malformed request. The `%s_endpoint_auth_method` methods determined to be used were '%s'.", client.GetID(), handler.Name(), strings.Join(methods, "', '"))) } switch { case assertion != nil: - method, err = s.doAuthenticateAssertionJWTBearer(ctx, client, assertion, resolver) + method, err = s.doAuthenticateAssertionJWTBearer(ctx, client, assertion, handler) case hasBasic, hasPost: - method, err = s.doAuthenticateClientSecret(ctx, client, secret, hasBasic, hasPost, resolver) + method, err = s.doAuthenticateClientSecret(ctx, client, secret, hasBasic, hasPost, handler) default: - method, err = s.doAuthenticateNone(ctx, client, resolver) + method, err = s.doAuthenticateNone(ctx, client, handler) } if err != nil { @@ -134,54 +134,58 @@ func (s *DefaultClientAuthenticationStrategy) authenticate(ctx context.Context, return client, method, nil } -func NewClientAssertion(ctx context.Context, store ClientManager, raw, assertionType string, resolver EndpointClientAuthHandler) (assertion *ClientAssertion, err error) { +// NewClientAssertion converts a raw assertion string into a *ClientAssertion. +func NewClientAssertion(ctx context.Context, strategy jwt.Strategy, store ClientManager, assertion, assertionType string, handler EndpointClientAuthHandler) (a *ClientAssertion, err error) { var ( - token *xjwt.Token + token *jwt.Token - id, alg, method string - client Client + id, method string + client Client ) switch assertionType { case consts.ClientAssertionTypeJWTBearer: - if len(raw) == 0 { - return &ClientAssertion{Raw: raw, Type: assertionType}, errorsx.WithStack(ErrInvalidRequest.WithHintf("The request parameter 'client_assertion' must be set when using 'client_assertion_type' of '%s'.", consts.ClientAssertionTypeJWTBearer)) + if len(assertion) == 0 { + return &ClientAssertion{Assertion: assertion, Type: assertionType}, errorsx.WithStack(ErrInvalidRequest.WithHintf("The request parameter 'client_assertion' must be set when using 'client_assertion_type' of '%s'.", consts.ClientAssertionTypeJWTBearer)) } default: - return &ClientAssertion{Raw: raw, Type: assertionType}, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unknown client_assertion_type '%s'.", assertionType)) + return &ClientAssertion{Assertion: assertion, Type: assertionType}, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unknown client_assertion_type '%s'.", assertionType)) } - if token, _, err = xjwt.NewParser(xjwt.WithoutClaimsValidation()).ParseUnverified(raw, &xjwt.MapClaims{}); err != nil { - return &ClientAssertion{Raw: raw, Type: assertionType}, resolveJWTErrorToRFCError(err) + if token, err = strategy.Decode(ctx, assertion, jwt.WithAllowUnverified(), jwt.WithSigAlgorithm(jwt.SignatureAlgorithmsNone...)); err != nil { + return &ClientAssertion{Assertion: assertion, Type: assertionType}, resolveJWTErrorToRFCError(err) } - if id, err = token.Claims.GetSubject(); err != nil { - if id, err = token.Claims.GetIssuer(); err != nil { - return &ClientAssertion{Raw: raw, Type: assertionType}, nil + if id, err = token.Claims.GetSubject(); err != nil || len(id) == 0 { + if id, err = token.Claims.GetIssuer(); err != nil || len(id) == 0 { + return &ClientAssertion{Assertion: assertion, Type: assertionType}, nil } } if client, err = store.GetClient(ctx, id); err != nil { - return &ClientAssertion{Raw: raw, Type: assertionType, ID: id}, nil + return &ClientAssertion{Assertion: assertion, Type: assertionType, ID: id}, nil } - if c, ok := client.(AuthenticationMethodClient); ok { - alg, method = resolver.GetAuthSigningAlg(c), resolver.GetAuthMethod(c) + method = consts.ClientAuthMethodPrivateKeyJWT + + if jwt.IsSignedJWTClientSecretAlg(token.SignatureAlgorithm) { + method = consts.ClientAuthMethodClientSecretJWT } return &ClientAssertion{ - Raw: raw, + Assertion: assertion, Type: assertionType, Parsed: true, ID: id, Method: method, - Algorithm: alg, + Algorithm: string(token.SignatureAlgorithm), Client: client, }, nil } +// ClientAssertion represents a client assertion. type ClientAssertion struct { - Raw, Type string + Assertion, Type string Parsed bool ID, Method, Algorithm string Client Client @@ -222,7 +226,7 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateClientSecret(ctx con return "", errorsx.WithStack( ErrInvalidClient. WithHintf("The request was determined to be using '%s_endpoint_auth_method' method '%s', however the OAuth 2.0 client registration does not allow this method.", handler.Name(), method). - WithDebugf("The registered client with id '%s' is configured to only support '%s_endpoint_auth_method' method '%s'. Either the Authorization Server client registration will need to have the '%s_endpoint_auth_method' updated to '%s' or the Relying Party will need to be configured to use '%s'.", client.GetID(), handler.Name(), cmethod, handler.Name(), method, cmethod)) + WithDebugf("The registered client with id '%s' is configured to only support '%s_endpoint_auth_method' method '%s'. Either the Authorization Server client registration will need to have the '%s_endpoint_auth_method' updated to '%s'x or the Relying Party will need to be configured to use '%s'.", client.GetID(), handler.Name(), cmethod, handler.Name(), method, cmethod)) } } @@ -240,35 +244,62 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateClientSecret(ctx con } } -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(ctx context.Context, client Client, assertion *ClientAssertion, resolver EndpointClientAuthHandler) (method string, err error) { +func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(ctx context.Context, client Client, assertion *ClientAssertion, handler EndpointClientAuthHandler) (method string, err error) { var ( - token *xjwt.Token - claims *xjwt.RegisteredClaims + token *jwt.Token + c AuthenticationMethodClient + ok bool ) - if method, _, _, token, claims, err = s.doAuthenticateAssertionParseAssertionJWTBearer(ctx, client, assertion, resolver); err != nil { + if c, ok = client.(AuthenticationMethodClient); !ok { + return "", errorsx.WithStack(ErrInvalidRequest.WithHint("The registered client does not support OAuth 2.0 JWT Profile Client Authentication RFC7523 or OpenID Connect 1.0 specific authentication methods.")) + } + + if !assertion.Parsed { + + } + + if method, _, _, token, err = s.doAuthenticateAssertionParseAssertionJWTBearer(ctx, c, assertion, handler); err != nil { return "", err } if token == nil { - return "", err + return "", errorsx.WithStack(ErrInvalidClient.WithDebug("The client assertion did not result in a parsed token.")) } clientID := []byte(client.GetID()) + claims := &jwt.JWTClaims{} + + claims.FromMapClaims(token.Claims.ToMapClaims()) + switch { case subtle.ConstantTimeCompare([]byte(claims.Issuer), clientID) == 0: - return "", errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) + return "", errorsx.WithStack(ErrInvalidClient.WithHint("The client assertion had invalid claims.").WithDebug("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) case subtle.ConstantTimeCompare([]byte(claims.Subject), clientID) == 0: - return "", errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) - case claims.ID == "": - return "", errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'jti' from 'client_assertion' must be set but is not.")) + return "", errorsx.WithStack(ErrInvalidClient.WithHint("The client assertion had invalid claims.").WithDebug("Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) + case claims.JTI == "": + return "", errorsx.WithStack(ErrInvalidClient.WithHint("The client assertion had invalid claims.").WithDebug("Claim 'jti' from 'client_assertion' must be set but is not.")) default: - if err = s.Store.ClientAssertionJWTValid(ctx, claims.ID); err != nil { + switch cmethod := handler.GetAuthMethod(c); { + case cmethod == "" && handler.AllowAuthMethodAny(): + break + case cmethod != method: + return "", errorsx.WithStack( + ErrInvalidClient. + WithHintf("The request was determined to be using '%s_endpoint_auth_method' method '%s', however the OAuth 2.0 client registration does not allow this method.", handler.Name(), method). + WithDebugf("The registered client with id '%s' is configured to only support '%s_endpoint_auth_method' method '%s'. Either the Authorization Server client registration will need to have the '%s_endpoint_auth_method' updated to '%s' or the Relying Party will need to be configured to use '%s'.", client.GetID(), handler.Name(), cmethod, handler.Name(), method, cmethod)) + } + + if !assertion.Parsed { + return "", errorsx.WithStack(ErrInvalidClient.WithDebug("The client assertion was not able to be parsed.")) + } + + if err = s.Store.ClientAssertionJWTValid(ctx, claims.JTI); err != nil { return "", errorsx.WithStack(ErrJTIKnown.WithHint("Claim 'jti' from 'client_assertion' MUST only be used once.").WithDebugError(err)) } - if err = s.Store.SetClientAssertionJWT(ctx, claims.ID, time.Unix(claims.ExpiresAt.Unix(), 0)); err != nil { + if err = s.Store.SetClientAssertionJWT(ctx, claims.JTI, time.Unix(claims.ExpiresAt.Unix(), 0)); err != nil { return "", err } @@ -276,168 +307,157 @@ func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearer(c } } -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearer(ctx context.Context, client Client, assertion *ClientAssertion, resolver EndpointClientAuthHandler) (method, kid, alg string, token *xjwt.Token, claims *xjwt.RegisteredClaims, err error) { +func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearer(ctx context.Context, client AuthenticationMethodClient, assertion *ClientAssertion, handler EndpointClientAuthHandler) (method, kid, alg string, token *jwt.Token, err error) { audience := s.Config.GetAllowedJWTAssertionAudiences(ctx) if len(audience) == 0 { - return "", "", "", nil, nil, errorsx.WithStack(ErrMisconfiguration.WithHint("The authorization server does not support OAuth 2.0 JWT Profile Client Authentication RFC7523 or OpenID Connect 1.0 specific authentication methods.").WithDebug("The authorization server could not determine any safe value for it's audience but it's required to validate the RFC7523 client assertions.")) - } - - opts := []xjwt.ParserOption{ - xjwt.WithStrictDecoding(), - //xjwt.WithAudience(tokenURI), // Satisfies RFC7523 Section 3 Point 3. - xjwt.WithExpirationRequired(), // Satisfies RFC7523 Section 3 Point 4. - xjwt.WithIssuedAt(), // Satisfies RFC7523 Section 3 Point 6. + return "", "", "", nil, errorsx.WithStack(ErrMisconfiguration.WithHint("The authorization server does not support OAuth 2.0 JWT Profile Client Authentication RFC7523 or OpenID Connect 1.0 specific authentication methods.").WithDebug("The authorization server could not determine any safe value for it's audience but it's required to validate the RFC7523 client assertions.")) } - // Automatically satisfies RFC7523 Section 3 Point 5, 8, 9, and 10. - parser := xjwt.NewParser(opts...) - - claims = &xjwt.RegisteredClaims{} - - if token, err = parser.ParseWithClaims(assertion.Raw, claims, func(token *xjwt.Token) (key any, err error) { - if subtle.ConstantTimeCompare([]byte(client.GetID()), []byte(claims.Subject)) == 0 { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The supplied 'client_id' did not match the 'sub' claim of the 'client_assertion'.")) - } - - // The following check satisfies RFC7523 Section 3 Point 2. - // See: https://datatracker.ietf.org/doc/html/rfc7523#section-3. - if claims.Subject == "" { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The claim 'sub' from the 'client_assertion' isn't defined.")) - } - - var ( - c AuthenticationMethodClient - ok bool - ) - - if c, ok = client.(AuthenticationMethodClient); !ok { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("The registered client does not support OAuth 2.0 JWT Profile Client Authentication RFC7523 or OpenID Connect 1.0 specific authentication methods.")) - } - - return s.doAuthenticateAssertionParseAssertionJWTBearerFindKey(ctx, token.Header, c, resolver) - }); err != nil { - return "", "", "", nil, nil, resolveJWTErrorToRFCError(err) + if token, err = s.Config.GetJWTStrategy(ctx).Decode(ctx, assertion.Assertion, jwt.WithClient(&EndpointClientAuthJWTClient{client: client, handler: handler}), jwt.WithSigAlgorithm(jwt.SignatureAlgorithmsNone...)); err != nil { + return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, client, handler, audience, err)) } - // Satisfies RFC7523 Section 3 Point 3. - if err = s.doAuthenticateAssertionJWTBearerClaimAudience(ctx, audience, claims); err != nil { - return "", "", "", nil, nil, err + optsClaims := []jwt.ClaimValidationOption{ + jwt.ValidateAudienceAny(audience...), // Satisfies RFC7523 Section 3 Point 3. + jwt.ValidateRequireExpiresAt(), // Satisfies RFC7523 Section 3 Point 4. + jwt.ValidateTimeFunc(time.Now), } - return method, kid, alg, token, claims, nil -} - -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionJWTBearerClaimAudience(ctx context.Context, audience []string, claims *xjwt.RegisteredClaims) (err error) { - if len(claims.Audience) == 0 { - return errorsx.WithStack( - ErrInvalidClient. - WithHint("Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint."). - WithDebug("Unable to validate the 'aud' claim of the 'client_assertion' as it was empty."), - ) + if err = token.Claims.Valid(optsClaims...); err != nil { + return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, client, handler, audience, err)) } - validAudience := false - - var aud, unverified string - -verification: - for _, unverified = range claims.Audience { - for _, aud = range audience { - if subtle.ConstantTimeCompare([]byte(aud), []byte(unverified)) == 1 { - validAudience = true - break verification - } - } + optsHeader := []jwt.HeaderValidationOption{ + jwt.ValidateKeyID(handler.GetAuthSigningKeyID(client)), + jwt.ValidateAlgorithm(handler.GetAuthSigningAlg(client)), + jwt.ValidateEncryptionKeyID(handler.GetAuthEncryptionKeyID(client)), + jwt.ValidateKeyAlgorithm(handler.GetAuthEncryptionAlg(client)), + jwt.ValidateContentEncryption(handler.GetAuthEncryptionEnc(client)), } - if !validAudience { - return errorsx.WithStack( - ErrInvalidClient. - WithHint("Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint."). - WithDebugf("Unable to validate the 'aud' claim of the 'client_assertion' value '%s' as it doesn't match any of the expected values '%s'.", strings.Join(claims.Audience, "', '"), strings.Join(audience, "', '")), - ) + if err = token.Valid(optsHeader...); err != nil { + return "", "", "", nil, errorsx.WithStack(fmtClientAssertionDecodeError(token, client, handler, audience, err)) } - return nil + return assertion.Method, kid, alg, token, nil } -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearerFindKey(ctx context.Context, header map[string]any, client AuthenticationMethodClient, handler EndpointClientAuthHandler) (key any, err error) { - var kid, alg, method string - - kid, alg = getJWTHeaderKIDAlg(header) - - if calg := handler.GetAuthSigningAlg(client); calg != alg && calg != "" { - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The requested OAuth 2.0 client does not support the '%s_endpoint_auth_signing_alg' value '%s'.", handler.Name(), alg).WithDebugf("The registered OAuth 2.0 client with id '%s' only supports the '%s' algorithm.", client.GetID(), calg)) - } +func (s *DefaultClientAuthenticationStrategy) getClientCredentialsSecretPost(form url.Values) (id, secret string, ok bool) { + id, secret = form.Get(consts.FormParameterClientID), form.Get(consts.FormParameterClientSecret) - switch method = handler.GetAuthMethod(client); method { - case consts.ClientAuthMethodClientSecretJWT: - return s.doAuthenticateAssertionParseAssertionJWTBearerFindKeyClientSecretJWT(ctx, kid, alg, client, handler) - case consts.ClientAuthMethodPrivateKeyJWT: - return s.doAuthenticateAssertionParseAssertionJWTBearerFindKeyPrivateKeyJWT(ctx, kid, alg, client, handler) - case consts.ClientAuthMethodNone: - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This requested OAuth 2.0 client does not support client authentication, however 'client_assertion' was provided in the request.")) - case consts.ClientAuthMethodClientSecretBasic, consts.ClientAuthMethodClientSecretPost: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however 'client_assertion' was provided in the request.", method)) - default: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however that method is not supported by this server.", method)) - } + return id, secret, len(id) != 0 && len(secret) != 0 } -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearerFindKeyClientSecretJWT(_ context.Context, _, alg string, client AuthenticationMethodClient, handler EndpointClientAuthHandler) (key any, err error) { - switch alg { - case xjwt.SigningMethodHS256.Alg(), xjwt.SigningMethodHS384.Alg(), xjwt.SigningMethodRS512.Alg(): - secret := client.GetClientSecret() - - if secret == nil || !secret.IsPlainText() { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 client does not support the client authentication method 'client_secret_jwt' ")) - } - - if key, err = secret.GetPlainTextValue(); err != nil { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 client does not support the client authentication method 'client_secret_jwt' ")) - } +func resolveJWTErrorToRFCError(err error) (rfc error) { + var e *RFC6749Error - return key, nil - default: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The requested OAuth 2.0 client does not support the '%s_endpoint_auth_signing_alg' value '%s'.", handler.Name(), alg)) + if errors.As(err, &e) { + return errorsx.WithStack(e) } -} -func (s *DefaultClientAuthenticationStrategy) doAuthenticateAssertionParseAssertionJWTBearerFindKeyPrivateKeyJWT(ctx context.Context, kid, alg string, client AuthenticationMethodClient, handler EndpointClientAuthHandler) (key any, err error) { - switch alg { - case xjwt.SigningMethodRS256.Alg(), xjwt.SigningMethodRS384.Alg(), xjwt.SigningMethodRS512.Alg(), - xjwt.SigningMethodPS256.Alg(), xjwt.SigningMethodPS384.Alg(), xjwt.SigningMethodPS512.Alg(), - xjwt.SigningMethodES256.Alg(), xjwt.SigningMethodES384.Alg(), xjwt.SigningMethodES512.Alg(): - if key, err = FindClientPublicJWK(ctx, s.Config, client, kid, alg, "sig"); err != nil { - return nil, err + if errJWTValidation := new(jwt.ValidationError); errors.As(err, &errJWTValidation) { + switch { + case errJWTValidation.Has(jwt.ValidationErrorMalformed): + e = ErrInvalidClient. + WithHint("OAuth 2.0 client provided a client assertion which could not be decoded or validated."). + WithWrap(err). + WithDebugf("OAuth 2.0 client provided a client assertion that was malformed. %s.", strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorMalformedNotCompactSerialized): + e = ErrInvalidClient. + WithHint("OAuth 2.0 client provided a client assertion which could not be decoded or validated."). + WithWrap(err). + WithDebugf("OAuth 2.0 client provided a client assertion that was malformed. The client assertion does not appear to be a JWE or JWS compact serialized JWT.") + case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): + e = ErrInvalidClient. + WithHint("OAuth 2.0 client provided a client assertion which could not be decoded or validated."). + WithWrap(err). + WithDebugf("OAuth 2.0 client provided a client assertion that was not able to be verified. %s.", strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + default: + e = ErrInvalidClient. + WithHint("OAuth 2.0 client provided a client assertion which could not be decoded or validated."). + WithWrap(err). + WithDebugf("Unknown error occurred handling the client assertion.") } - - return key, nil - default: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The requested OAuth 2.0 client does not support the '%s_endpoint_auth_signing_alg' value '%s'.", handler.Name(), alg)) } -} - -func (s *DefaultClientAuthenticationStrategy) getClientCredentialsSecretPost(form url.Values) (id, secret string, ok bool) { - id, secret = form.Get(consts.FormParameterClientID), form.Get(consts.FormParameterClientSecret) - return id, secret, len(id) != 0 && len(secret) != 0 + return errorsx.WithStack(e) } -func resolveJWTErrorToRFCError(err error) (rfc error) { - var e *RFC6749Error - - switch { - case errors.As(err, &e): - return errorsx.WithStack(e) - case errors.Is(err, xjwt.ErrTokenMalformed): - return errorsx.WithStack(ErrInvalidClient.WithHint("Unable to decode the 'client_assertion' value as it is malformed or incomplete.").WithWrap(err).WithDebugError(err)) - case errors.Is(err, xjwt.ErrTokenUnverifiable): - return errorsx.WithStack(ErrInvalidClient.WithHint("Unable to decode the 'client_assertion' value as it is missing the information required to validate it.").WithWrap(err).WithDebugError(err)) - case errors.Is(err, xjwt.ErrTokenNotValidYet), errors.Is(err, xjwt.ErrTokenExpired), errors.Is(err, xjwt.ErrTokenUsedBeforeIssued): - return errorsx.WithStack(ErrInvalidClient.WithHint("Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint.").WithWrap(err).WithDebugError(err)) - default: - return errorsx.WithStack(ErrInvalidClient.WithHint("Unable to decode 'client_assertion' value for an unknown reason.").WithWrap(err).WithDebugError(err)) +func fmtClientAssertionDecodeError(token *jwt.Token, client AuthenticationMethodClient, handler EndpointClientAuthHandler, audience []string, inner error) (outer *RFC6749Error) { + outer = ErrInvalidClient.WithWrap(inner).WithHintf("OAuth 2.0 client with id '%s' provided a client assertion which could not be decoded or validated.", client.GetID()) + + if errJWTValidation := new(jwt.ValidationError); errors.As(inner, &errJWTValidation) { + switch { + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyIDInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'kid' header value '%s' due to the client registration 'request_object_signing_key_id' value but the client assertion was signed with the 'kid' header value '%s'.", client.GetID(), handler.GetAuthSigningKeyID(client), token.KeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderAlgorithmInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'alg' header value '%s' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' header value '%s'.", client.GetID(), handler.GetAuthSigningAlg(client), token.SignatureAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderTypeInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be signed with the 'typ' header value '%s' but the client assertion was signed with the 'typ' header value '%s'.", client.GetID(), consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionTypeInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'typ' header value '%s' but the client assertion was encrypted with the 'typ' header value '%s'.", client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalidMismatch): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with a 'cty' header value and signed with a 'typ' value that match but the client assertions was encrypted with the 'cty' header value '%s' and signed with the 'typ' header value '%s'.", client.GetID(), token.HeaderJWE[consts.JSONWebTokenHeaderContentType], token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'cty' header value '%s' but the client assertion was encrypted with the 'cty' header value '%s'.", client.GetID(), consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderContentType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'kid' header value '%s' due to the client registration 'request_object_encryption_key_id' value but the client assertion was encrypted with the 'kid' header value '%s'.", client.GetID(), handler.GetAuthEncryptionKeyID(client), token.EncryptionKeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyAlgorithmInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'alg' header value '%s' due to the client registration 'request_object_encryption_alg' value but the client assertion was encrypted with the 'alg' header value '%s'.", client.GetID(), handler.GetAuthEncryptionAlg(client), token.KeyAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' expects client assertions to be encrypted with the 'enc' header value '%s' due to the client registration 'request_object_encryption_enc' value but the client assertion was encrypted with the 'enc' header value '%s'.", client.GetID(), handler.GetAuthEncryptionEnc(client), token.ContentEncryption) + case errJWTValidation.Has(jwt.ValidationErrorMalformedNotCompactSerialized): + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was malformed. The client assertion does not appear to be a JWE or JWS compact serialized JWT.", client.GetID()) + case errJWTValidation.Has(jwt.ValidationErrorMalformed): + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was malformed. %s.", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was not able to be verified. %s.", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorSignatureInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid signature. %s.", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorExpired): + exp, err := token.Claims.GetExpirationTime() + if err == nil { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was expired. The client assertion expired at %d.", client.GetID(), exp.Int64()) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was expired. The client assertion does not have an 'exp' claim or it has an invalid type.", client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuedAt): + iat, err := token.Claims.GetIssuedAt() + if err == nil { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion was issued at %d.", client.GetID(), iat.Int64()) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion does not have an 'iat' claim or it has an invalid type.", client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorNotValidYet): + nbf, err := token.Claims.GetNotBefore() + if err == nil { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion is not valid before %d.", client.GetID(), nbf.Int64()) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that was issued in the future. The client assertion does not have an 'nbf' claim or it has an invalid type.", client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuer): + iss, err := token.Claims.GetIssuer() + if err == nil { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid issuer. The client assertion was expected to have an 'iss' claim which matches the value '%s' but the 'iss' claim had the value '%s'.", client.GetID(), client.GetID(), iss) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid issuer. The client assertion does not have an 'iss' claim or it has an invalid type.", client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorAudience): + aud, err := token.Claims.GetAudience() + if err == nil { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid audience. The client assertion was expected to have an 'aud' claim which matches one of the values '%s' but the 'aud' claim had the values '%s'.", client.GetID(), strings.Join(audience, "', '"), strings.Join(aud, "', '")) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that has an invalid audience. The client assertion does not have an 'aud' claim or it has an invalid type.", client.GetID()) + } + case errJWTValidation.Has(jwt.ValidationErrorClaimsInvalid): + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that had one or more invalid claims. Error occurred trying to validate the client assertions claims: %s", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + default: + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that could not be validated. Error occurred trying to validate the client assertion: %s", client.GetID(), strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + } + } else if errJWKLookup := new(jwt.JWKLookupError); errors.As(inner, &errJWKLookup) { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that could not be validated due to a key lookup error. %s.", client.GetID(), errJWKLookup.Description) + } else { + return outer.WithDebugf("OAuth 2.0 client with id '%s' provided a client assertion that could not be validated. %s.", client.GetID(), ErrorToDebugRFC6749Error(inner).Error()) } } diff --git a/client_authentication_test.go b/client_authentication_test.go index 6ce06cfd..142910bf 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -12,6 +12,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "regexp" "testing" "time" @@ -29,12 +30,14 @@ import ( func TestAuthenticateClient(t *testing.T) { keyRSA := gen.MustRSAKey() + jwksRSA := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - KeyID: "kid-foo", - Use: "sig", - Key: &keyRSA.PublicKey, + KeyID: "kid-foo", + Use: "sig", + Algorithm: "RS256", + Key: &keyRSA.PublicKey, }, }, } @@ -43,9 +46,27 @@ func TestAuthenticateClient(t *testing.T) { jwksECDSA := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - KeyID: "kid-foo", - Use: "sig", - Key: &keyECDSA.PublicKey, + KeyID: "kid-foo", + Use: "sig", + Algorithm: "ES256", + Key: &keyECDSA.PublicKey, + }, + }, + } + + jwks := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "kid-foo", + Use: "sig", + Algorithm: "RS256", + Key: &keyRSA.PublicKey, + }, + { + KeyID: "kid-foo", + Use: "sig", + Algorithm: "ES256", + Key: &keyECDSA.PublicKey, }, }, } @@ -60,6 +81,7 @@ func TestAuthenticateClient(t *testing.T) { r *http.Request form url.Values err string + errRegexp *regexp.Regexp expectErr error }{ { @@ -372,7 +394,7 @@ func TestAuthenticateClient(t *testing.T) { { name: "ShouldFailBecauseRSAAssertionIsUsedButECDSAAssertionIsRequired", client: func(ts *httptest.Server) Client { - return &DefaultJARClient{DefaultClient: &DefaultClient{ID: "bar", ClientSecret: testClientSecretBar}, JSONWebKeys: jwksECDSA, TokenEndpointAuthMethod: "private_key_jwt", TokenEndpointAuthSigningAlg: "ES256"} + return &DefaultJARClient{DefaultClient: &DefaultClient{ID: "bar", ClientSecret: testClientSecretBar}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt", TokenEndpointAuthSigningAlg: "ES256"} }, form: url.Values{"client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ consts.ClaimSubject: "bar", consts.ClaimExpirationTime: time.Now().Add(time.Hour).Unix(), @@ -382,7 +404,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The requested OAuth 2.0 client does not support the 'token_endpoint_auth_signing_alg' value 'RS256'. The registered OAuth 2.0 client with id 'bar' only supports the 'ES256' algorithm.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' expects client assertions to be signed with the 'alg' header value 'ES256' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' header value 'RS256'.", }, { name: "ShouldFailBecauseMalformedAssertionUsed", @@ -391,7 +413,7 @@ func TestAuthenticateClient(t *testing.T) { }, form: url.Values{"client_assertion": []string{"bad.assertion"}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to decode the 'client_assertion' value as it is malformed or incomplete. token is malformed: token contains an invalid number of segments", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client provided a client assertion which could not be decoded or validated. OAuth 2.0 client provided a client assertion that was malformed. The client assertion does not appear to be a JWE or JWS compact serialized JWT.", }, { name: "ShouldFailBecauseExpired", @@ -406,7 +428,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyECDSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint. token has invalid claims: token is expired", + errRegexp: regexp.MustCompile(`^Client authentication failed \(e\.g\., unknown client, no client authentication included, or unsupported authentication method\)\. OAuth 2\.0 client with id 'bar' provided a client assertion which could not be decoded or validated\. OAuth 2\.0 client with id 'bar' provided a client assertion that was expired\. The client assertion expired at \d+\.$`), }, { name: "ShouldFailBecauseNotBefore", @@ -422,7 +444,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyECDSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint. token has invalid claims: token is not valid yet", + errRegexp: regexp.MustCompile(`^Client authentication failed \(e\.g\., unknown client, no client authentication included, or unsupported authentication method\)\. OAuth 2\.0 client with id 'bar' provided a client assertion which could not be decoded or validated\. OAuth 2\.0 client with id 'bar' provided a client assertion that was issued in the future\. The client assertion is not valid before \d+\.$`), }, { name: "ShouldFailBecauseIssuedInFuture", @@ -438,7 +460,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyECDSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint. token has invalid claims: token used before issued", + errRegexp: regexp.MustCompile(`^Client authentication failed \(e\.g\., unknown client, no client authentication included, or unsupported authentication method\)\. OAuth 2\.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2\.0 client with id 'bar' provided a client assertion that was issued in the future\. The client assertion was issued at \d+\.$`), }, { name: "ShouldFailBecauseNoKeys", @@ -453,10 +475,10 @@ func TestAuthenticateClient(t *testing.T) { }, keyECDSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' provided a client assertion that was not able to be verified. Error occurred retrieving the JSON Web Key. No JWKs have been registered for the client.", }, { - name: "ShouldFailBecauseNotBefore", + name: "ShouldFailBecauseNotBeforeAlternative", client: func(ts *httptest.Server) Client { return &DefaultJARClient{DefaultClient: &DefaultClient{ID: "bar", ClientSecret: testClientSecretBar}, JSONWebKeys: jwksECDSA, TokenEndpointAuthMethod: "private_key_jwt", TokenEndpointAuthSigningAlg: "ES256"} }, form: url.Values{"client_assertion": {mustGenerateECDSAAssertion(t, jwt.MapClaims{ @@ -469,7 +491,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyECDSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint. token has invalid claims: token is not valid yet", + errRegexp: regexp.MustCompile(`^Client authentication failed \(e\.g\., unknown client, no client authentication included, or unsupported authentication method\)\. OAuth 2\.0 client with id 'bar' provided a client assertion which could not be decoded or validated\. OAuth 2\.0 client with id 'bar' provided a client assertion that was issued in the future\. The client assertion is not valid before \d+\.$`), }, { name: "ShouldFailBecauseTokenAuthMethodIsNotPrivateKeyJwtButClientSecretJwt", @@ -484,7 +506,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The requested OAuth 2.0 client does not support the 'token_endpoint_auth_signing_alg' value 'RS256'.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The request was determined to be using 'token_endpoint_auth_method' method 'private_key_jwt', however the OAuth 2.0 client registration does not allow this method. The registered client with id 'bar' is configured to only support 'token_endpoint_auth_method' method 'client_secret_jwt'. Either the Authorization Server client registration will need to have the 'token_endpoint_auth_method' updated to 'private_key_jwt' or the Relying Party will need to be configured to use 'client_secret_jwt'.", }, { name: "ShouldFailBecauseTokenAuthMethodIsNotPrivateKeyJwtButNone", @@ -499,7 +521,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). This requested OAuth 2.0 client does not support client authentication, however 'client_assertion' was provided in the request.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The request was determined to be using 'token_endpoint_auth_method' method 'private_key_jwt', however the OAuth 2.0 client registration does not allow this method. The registered client with id 'bar' is configured to only support 'token_endpoint_auth_method' method 'none'. Either the Authorization Server client registration will need to have the 'token_endpoint_auth_method' updated to 'private_key_jwt' or the Relying Party will need to be configured to use 'none'.", }, { name: "ShouldFailBecauseTokenAuthMethodIsNotPrivateKeyJwtButClientSecretPost", @@ -514,7 +536,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). This requested OAuth 2.0 client only supports client authentication method 'client_secret_post', however 'client_assertion' was provided in the request.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The request was determined to be using 'token_endpoint_auth_method' method 'private_key_jwt', however the OAuth 2.0 client registration does not allow this method. The registered client with id 'bar' is configured to only support 'token_endpoint_auth_method' method 'client_secret_post'. Either the Authorization Server client registration will need to have the 'token_endpoint_auth_method' updated to 'private_key_jwt' or the Relying Party will need to be configured to use 'client_secret_post'.", }, { name: "ShouldFailBecauseTokenAuthMethodIsNotPrivateKeyJwtButClientSecretBasic", @@ -529,7 +551,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). This requested OAuth 2.0 client only supports client authentication method 'client_secret_basic', however 'client_assertion' was provided in the request.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The request was determined to be using 'token_endpoint_auth_method' method 'private_key_jwt', however the OAuth 2.0 client registration does not allow this method. The registered client with id 'bar' is configured to only support 'token_endpoint_auth_method' method 'client_secret_basic'. Either the Authorization Server client registration will need to have the 'token_endpoint_auth_method' updated to 'private_key_jwt' or the Relying Party will need to be configured to use 'client_secret_basic'.", }, { name: "ShouldFailBecauseTokenAuthMethodIsNotPrivateKeyJwtButFoobar", @@ -571,7 +593,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Unable to verify the integrity of the 'client_assertion' value. It may have been used before it was issued, may have been used before it's allowed to be used, may have been used after it's expired, or otherwise doesn't meet a particular validation constraint. Unable to validate the 'aud' claim of the 'client_assertion' value 'token-url-1', 'token-url-2' as it doesn't match any of the expected values 'token-url'.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' provided a client assertion that has an invalid audience. The client assertion was expected to have an 'aud' claim which matches one of the values 'token-url' but the 'aud' claim had the values 'token-url-1', 'token-url-2'.", }, { name: "ShouldPassWithProperAssertionWhenJWKsAreSetWithinTheClient", @@ -613,7 +635,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The requested OAuth 2.0 client does not support the 'token_endpoint_auth_signing_alg' value 'none'. The registered OAuth 2.0 client with id 'bar' only supports the 'RS256' algorithm.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). OAuth 2.0 client with id 'bar' provided a client assertion which could not be decoded or validated. OAuth 2.0 client with id 'bar' expects client assertions to be signed with the 'alg' header value 'RS256' due to the client registration 'request_object_signing_alg' value but the client assertion was signed with the 'alg' header value 'none'.", }, { name: "ShouldPassWithProperAssertionWhenJWKsURIIsSet", @@ -641,7 +663,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The supplied 'client_id' did not match the 'sub' claim of the 'client_assertion'.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The client assertion had invalid claims. Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.", }, { name: "ShouldFailBecauseClientAssertionIssDoesNotMatchClient", @@ -656,10 +678,10 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The client assertion had invalid claims. Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.", }, { - name: "ShouldFailBecauseClientAssertionJtiIsNotSet", + name: "ShouldFailBecauseClientAssertionJTIClaimIsNotSet", client: func(ts *httptest.Server) Client { return &DefaultJARClient{DefaultClient: &DefaultClient{ID: "bar", ClientSecret: testClientSecretBar}, JSONWebKeys: jwksRSA, TokenEndpointAuthMethod: "private_key_jwt"} }, form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ @@ -670,7 +692,7 @@ func TestAuthenticateClient(t *testing.T) { }, keyRSA, "kid-foo")}, "client_assertion_type": []string{consts.ClientAssertionTypeJWTBearer}}, r: new(http.Request), expectErr: ErrInvalidClient, - err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). Claim 'jti' from 'client_assertion' must be set but is not.", + err: "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The client assertion had invalid claims. Claim 'jti' from 'client_assertion' must be set but is not.", }, { name: "ShouldFailBecauseClientAssertionAudIsNotSet", @@ -694,13 +716,20 @@ func TestAuthenticateClient(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + config := &Config{ + JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), + AllowedJWTAssertionAudiences: []string{"token-url"}, + HTTPClient: retryablehttp.NewClient(), + } + + config.JWTStrategy = &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerUnverifiedFromJWKS(jwks), + } + provider := &Fosite{ - Store: storage.NewMemoryStore(), - Config: &Config{ - JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), - AllowedJWTAssertionAudiences: []string{"token-url"}, - HTTPClient: retryablehttp.NewClient(), - }, + Store: storage.NewMemoryStore(), + Config: config, } var h http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { @@ -720,11 +749,7 @@ func TestAuthenticateClient(t *testing.T) { c, _, err := provider.AuthenticateClient(context.Background(), tc.r, tc.form) - if len(tc.err) != 0 { - require.EqualError(t, ErrorToDebugRFC6749Error(err), tc.err) - } - - if len(tc.err) == 0 && tc.expectErr == nil { + if len(tc.err) == 0 && tc.expectErr == nil && tc.errRegexp == nil { require.NoError(t, ErrorToDebugRFC6749Error(err)) assert.EqualValues(t, client, c) } else { @@ -735,6 +760,10 @@ func TestAuthenticateClient(t *testing.T) { if tc.expectErr != nil { assert.EqualError(t, err, tc.expectErr.Error()) } + + if tc.errRegexp != nil { + require.Regexp(t, tc.errRegexp, ErrorToDebugRFC6749Error(err).Error()) + } } }) } @@ -750,9 +779,10 @@ func TestAuthenticateClientTwice(t *testing.T) { JSONWebKeys: &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - KeyID: "kid-foo", - Use: consts.JSONWebTokenUseSignature, - Key: &key.PublicKey, + KeyID: "kid-foo", + Use: consts.JSONWebTokenUseSignature, + Algorithm: "RS256", + Key: &key.PublicKey, }, }, }, @@ -761,12 +791,19 @@ func TestAuthenticateClientTwice(t *testing.T) { store := storage.NewMemoryStore() store.Clients[client.ID] = client + config := &Config{ + JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), + AllowedJWTAssertionAudiences: []string{"token-url"}, + } + + config.JWTStrategy = &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + provider := &Fosite{ - Store: store, - Config: &Config{ - JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), - AllowedJWTAssertionAudiences: []string{"token-url"}, - }, + Store: store, + Config: config, } formValues := url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ @@ -778,13 +815,14 @@ func TestAuthenticateClientTwice(t *testing.T) { }, key, "kid-foo")}, consts.FormParameterClientAssertionType: []string{consts.ClientAssertionTypeJWTBearer}} c, _, err := provider.AuthenticateClient(context.TODO(), new(http.Request), formValues) - require.NoError(t, err, "%#v", err) + require.NoError(t, ErrorToDebugRFC6749Error(err)) assert.Equal(t, client, c) // replay the request and expect it to fail c, _, err = provider.AuthenticateClient(context.TODO(), new(http.Request), formValues) require.Error(t, err) assert.EqualError(t, err, ErrJTIKnown.Error()) + assert.EqualError(t, ErrorToDebugRFC6749Error(err), "The jti was already used. Claim 'jti' from 'client_assertion' MUST only be used once. The jti was already used.") assert.Nil(t, c) } @@ -792,7 +830,7 @@ func TestAuthenticateClientTwice(t *testing.T) { func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jose.RS256, claims) token.Header["kid"] = kid - tokenString, err := token.SignedString(key) + tokenString, err := token.CompactSignedString(key) require.NoError(t, err) return tokenString } @@ -800,7 +838,7 @@ func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.Priva func mustGenerateECDSAAssertion(t *testing.T, claims jwt.MapClaims, key *ecdsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jose.ES256, claims) token.Header["kid"] = kid - tokenString, err := token.SignedString(key) + tokenString, err := token.CompactSignedString(key) require.NoError(t, err) return tokenString } @@ -808,7 +846,7 @@ func mustGenerateECDSAAssertion(t *testing.T, claims jwt.MapClaims, key *ecdsa.P //nolint:unparam func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jose.HS256, claims) - tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")) + tokenString, err := token.CompactSignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")) require.NoError(t, err) return tokenString } @@ -816,7 +854,7 @@ func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.Privat //nolint:unparam func mustGenerateNoneAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) - tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) + tokenString, err := token.CompactSignedString(jwt.UnsafeAllowNoneSignatureType) require.NoError(t, err) return tokenString } diff --git a/compose/compose.go b/compose/compose.go index 0b002776..73ef4b40 100644 --- a/compose/compose.go +++ b/compose/compose.go @@ -69,13 +69,19 @@ func ComposeAllEnabled(config *oauth2.Config, storage any, key any) oauth2.Provi keyGetter := func(context.Context) (any, error) { return key, nil } + + strategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + return Compose( config, storage, &CommonStrategy{ CoreStrategy: NewOAuth2HMACStrategy(config), - OpenIDConnectTokenStrategy: NewOpenIDConnectStrategy(keyGetter, config), - Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter}, + OpenIDConnectTokenStrategy: NewOpenIDConnectStrategy(keyGetter, strategy, config), + Strategy: strategy, }, OAuth2AuthorizeExplicitFactory, OAuth2AuthorizeImplicitFactory, diff --git a/compose/compose_oauth2.go b/compose/compose_oauth2.go index 44492c32..33cb1f28 100644 --- a/compose/compose_oauth2.go +++ b/compose/compose_oauth2.go @@ -111,7 +111,7 @@ func OAuth2TokenIntrospectionFactory(config oauth2.Configurator, storage any, st // If you need revocation, you can validate JWTs statefully, using the other factories. func OAuth2StatelessJWTIntrospectionFactory(config oauth2.Configurator, storage any, strategy any) any { return &hoauth2.StatelessJWTValidator{ - Signer: strategy.(jwt.Signer), - Config: config, + Strategy: strategy.(jwt.Strategy), + Config: config, } } diff --git a/compose/compose_openid.go b/compose/compose_openid.go index 4c4676dc..abe3b2e0 100644 --- a/compose/compose_openid.go +++ b/compose/compose_openid.go @@ -20,7 +20,7 @@ func OpenIDConnectExplicitFactory(config oauth2.Configurator, storage any, strat IDTokenHandleHelper: &openid.IDTokenHandleHelper{ IDTokenStrategy: strategy.(openid.OpenIDConnectTokenStrategy), }, - OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Signer), config), + OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Strategy), config), Config: config, } } @@ -51,7 +51,7 @@ func OpenIDConnectImplicitFactory(config oauth2.Configurator, storage any, strat IDTokenHandleHelper: &openid.IDTokenHandleHelper{ IDTokenStrategy: strategy.(openid.OpenIDConnectTokenStrategy), }, - OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Signer), config), + OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Strategy), config), } } @@ -77,14 +77,14 @@ func OpenIDConnectHybridFactory(config oauth2.Configurator, storage any, strateg IDTokenStrategy: strategy.(openid.OpenIDConnectTokenStrategy), }, OpenIDConnectRequestStorage: storage.(openid.OpenIDConnectRequestStorage), - OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Signer), config), + OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Strategy), config), } } func OpenIDConnectDeviceAuthorizeFactory(config oauth2.Configurator, storage any, strategy any) any { return &openid.OpenIDConnectDeviceAuthorizeHandler{ OpenIDConnectRequestStorage: storage.(openid.OpenIDConnectRequestStorage), - OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Signer), config), + OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Strategy), config), CodeTokenEndpointHandler: &rfc8628.DeviceCodeTokenHandler{ Strategy: strategy.(rfc8628.CodeStrategy), Storage: storage.(rfc8628.Storage), diff --git a/compose/compose_strategy.go b/compose/compose_strategy.go index e03dda1c..cc272ff4 100644 --- a/compose/compose_strategy.go +++ b/compose/compose_strategy.go @@ -16,7 +16,7 @@ import ( type CommonStrategy struct { hoauth2.CoreStrategy openid.OpenIDConnectTokenStrategy - jwt.Signer + jwt.Strategy } type HMACSHAStrategyConfigurator interface { @@ -37,17 +37,17 @@ func NewOAuth2HMACStrategy(config HMACSHAStrategyConfigurator) *hoauth2.HMACCore } } -func NewOAuth2JWTStrategy(keyGetter func(context.Context) (any, error), strategy *hoauth2.HMACCoreStrategy, config oauth2.Configurator) *hoauth2.JWTProfileCoreStrategy { +func NewOAuth2JWTStrategy(strategy jwt.Strategy, strategyHMAC *hoauth2.HMACCoreStrategy, config oauth2.Configurator) *hoauth2.JWTProfileCoreStrategy { return &hoauth2.JWTProfileCoreStrategy{ - Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter}, - HMACCoreStrategy: strategy, + Strategy: strategy, + HMACCoreStrategy: strategyHMAC, Config: config, } } -func NewOpenIDConnectStrategy(keyGetter func(context.Context) (any, error), config oauth2.Configurator) *openid.DefaultStrategy { +func NewOpenIDConnectStrategy(keyGetter func(context.Context) (any, error), strategy jwt.Strategy, config oauth2.Configurator) *openid.DefaultStrategy { return &openid.DefaultStrategy{ - Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter}, - Config: config, + Strategy: strategy, + Config: config, } } diff --git a/config.go b/config.go index 7c16abd8..49cdd182 100644 --- a/config.go +++ b/config.go @@ -101,10 +101,10 @@ type IntrospectionIssuerProvider interface { GetIntrospectionIssuer(ctx context.Context) (issuer string) } -// IntrospectionJWTResponseSignerProvider returns the provider for configuring the Introspection signer. -type IntrospectionJWTResponseSignerProvider interface { - // GetIntrospectionJWTResponseSigner returns the Introspection JWT signer. - GetIntrospectionJWTResponseSigner(ctx context.Context) jwt.Signer +// IntrospectionJWTResponseStrategyProvider returns the provider for configuring the Introspection jwt.Strategy. +type IntrospectionJWTResponseStrategyProvider interface { + // GetIntrospectionJWTResponseStrategy returns the Introspection JWT Strategy. + GetIntrospectionJWTResponseStrategy(ctx context.Context) jwt.Strategy } // AuthorizationServerIssuerIdentificationProvider provides OAuth 2.0 Authorization Server Issuer Identification related methods. @@ -124,10 +124,15 @@ type JWTSecuredAuthorizeResponseModeIssuerProvider interface { GetJWTSecuredAuthorizeResponseModeIssuer(ctx context.Context) string } -// JWTSecuredAuthorizeResponseModeSignerProvider returns the provider for configuring the JARM signer. -type JWTSecuredAuthorizeResponseModeSignerProvider interface { - // GetJWTSecuredAuthorizeResponseModeSigner returns the JARM signer. - GetJWTSecuredAuthorizeResponseModeSigner(ctx context.Context) jwt.Signer +// JWTSecuredAuthorizeResponseModeStrategyProvider returns the provider for configuring the JARM jwt.Strategy. +type JWTSecuredAuthorizeResponseModeStrategyProvider interface { + // GetJWTSecuredAuthorizeResponseModeStrategy returns the JARM Strategy. + GetJWTSecuredAuthorizeResponseModeStrategy(ctx context.Context) jwt.Strategy +} + +// JWTStrategyProvider returns the provider for configuring the jwt.Strategy. +type JWTStrategyProvider interface { + GetJWTStrategy(ctx context.Context) jwt.Strategy } // JWTSecuredAuthorizeResponseModeLifespanProvider returns the provider for configuring the JWT Secured Authorize Response Mode token lifespan. @@ -253,7 +258,7 @@ type RevokeRefreshTokensExplicitlyProvider interface { // JWKSFetcherStrategyProvider returns the provider for configuring the JWKS fetcher strategy. type JWKSFetcherStrategyProvider interface { // GetJWKSFetcherStrategy returns the JWKS fetcher strategy. - GetJWKSFetcherStrategy(ctx context.Context) (strategy JWKSFetcherStrategy) + GetJWKSFetcherStrategy(ctx context.Context) (strategy jwt.JWKSFetcherStrategy) } // HTTPClientProvider returns the provider for configuring the HTTP client. diff --git a/config_default.go b/config_default.go index f1a105f2..961f1856 100644 --- a/config_default.go +++ b/config_default.go @@ -54,8 +54,8 @@ type Config struct { // IntrospectionIssuer is the issuer to be used when generating signed introspection responses. IntrospectionIssuer string - // IntrospectionJWTResponseSigner is the signer for Introspection Responses. Has no default. - IntrospectionJWTResponseSigner jwt.Signer + // IntrospectionJWTResponseStrategy is the signer for Introspection Responses. Has no default. + IntrospectionJWTResponseStrategy jwt.Strategy // HashCost sets the cost of the password hashing cost. Defaults to 12. HashCost int @@ -102,7 +102,7 @@ type Config struct { // JWKSFetcherStrategy is responsible for fetching JSON Web Keys from remote URLs. This is required when the private_key_jwt // client authentication method is used. Defaults to oauth2.DefaultJWKSFetcherStrategy. - JWKSFetcherStrategy JWKSFetcherStrategy + JWKSFetcherStrategy jwt.JWKSFetcherStrategy // TokenEntropy indicates the entropy of the random string, used as the "message" part of the HMAC token. // Defaults to 32. @@ -175,8 +175,11 @@ type Config struct { // JWT Secured Authorization Response Mode. Defaults to 10 minutes. JWTSecuredAuthorizeResponseModeLifespan time.Duration - // JWTSecuredAuthorizeResponseModeSigner is the signer for JWT Secured Authorization Response Mode. Has no default. - JWTSecuredAuthorizeResponseModeSigner jwt.Signer + // JWTSecuredAuthorizeResponseModeStrategy is the signer for JWT Secured Authorization Response Mode. Has no default. + JWTSecuredAuthorizeResponseModeStrategy jwt.Strategy + + // JWTStrategy handles less specific jwt.Strategy cases. + JWTStrategy jwt.Strategy // EnforceJWTProfileAccessTokens forces the issuer to return JWT Profile Access Tokens to all clients. EnforceJWTProfileAccessTokens bool @@ -333,8 +336,8 @@ func (c *Config) GetIntrospectionIssuer(ctx context.Context) string { return c.IntrospectionIssuer } -func (c *Config) GetIntrospectionJWTResponseSigner(ctx context.Context) jwt.Signer { - return c.IntrospectionJWTResponseSigner +func (c *Config) GetIntrospectionJWTResponseStrategy(ctx context.Context) jwt.Strategy { + return c.IntrospectionJWTResponseStrategy } // GetGrantTypeJWTBearerIssuedDateOptional returns the GrantTypeJWTBearerIssuedDateOptional field. @@ -389,8 +392,18 @@ func (c *Config) GetJWTSecuredAuthorizeResponseModeIssuer(ctx context.Context) s return c.IDTokenIssuer } -func (c *Config) GetJWTSecuredAuthorizeResponseModeSigner(ctx context.Context) jwt.Signer { - return c.JWTSecuredAuthorizeResponseModeSigner +func (c *Config) GetJWTSecuredAuthorizeResponseModeStrategy(ctx context.Context) jwt.Strategy { + return c.JWTSecuredAuthorizeResponseModeStrategy +} + +func (c *Config) GetJWTStrategy(ctx context.Context) jwt.Strategy { + if c.JWTStrategy == nil { + c.JWTStrategy = &jwt.DefaultStrategy{ + Config: c, + } + } + + return c.JWTStrategy } func (c *Config) GetEnforceJWTProfileAccessTokens(ctx context.Context) (enable bool) { @@ -496,8 +509,8 @@ func (c *Config) GetBCryptCost(_ context.Context) int { return c.HashCost } -// GetJWKSFetcherStrategy returns the JWKSFetcherStrategy. -func (c *Config) GetJWKSFetcherStrategy(_ context.Context) JWKSFetcherStrategy { +// GetJWKSFetcherStrategy returns the jwt.JWKSFetcherStrategy. +func (c *Config) GetJWKSFetcherStrategy(_ context.Context) jwt.JWKSFetcherStrategy { if c.JWKSFetcherStrategy == nil { c.JWKSFetcherStrategy = NewDefaultJWKSFetcherStrategy() } @@ -628,7 +641,7 @@ var ( _ AccessTokenIssuerProvider = (*Config)(nil) _ JWTScopeFieldProvider = (*Config)(nil) _ JWTSecuredAuthorizeResponseModeIssuerProvider = (*Config)(nil) - _ JWTSecuredAuthorizeResponseModeSignerProvider = (*Config)(nil) + _ JWTSecuredAuthorizeResponseModeStrategyProvider = (*Config)(nil) _ JWTSecuredAuthorizeResponseModeLifespanProvider = (*Config)(nil) _ JWTProfileAccessTokensProvider = (*Config)(nil) _ AllowedPromptsProvider = (*Config)(nil) @@ -666,5 +679,5 @@ var ( _ RFC8628DeviceAuthorizeEndpointHandlersProvider = (*Config)(nil) _ RFC8628UserAuthorizeEndpointHandlersProvider = (*Config)(nil) _ IntrospectionIssuerProvider = (*Config)(nil) - _ IntrospectionJWTResponseSignerProvider = (*Config)(nil) + _ IntrospectionJWTResponseStrategyProvider = (*Config)(nil) ) diff --git a/fosite.go b/fosite.go index 8c78bbf3..e23cce4c 100644 --- a/fosite.go +++ b/fosite.go @@ -156,7 +156,7 @@ type Configurator interface { SanitationAllowedProvider JWTScopeFieldProvider JWTSecuredAuthorizeResponseModeIssuerProvider - JWTSecuredAuthorizeResponseModeSignerProvider + JWTSecuredAuthorizeResponseModeStrategyProvider JWTSecuredAuthorizeResponseModeLifespanProvider JWTProfileAccessTokensProvider AccessTokenIssuerProvider @@ -196,7 +196,8 @@ type Configurator interface { RFC8628UserAuthorizeEndpointHandlersProvider RFC9628DeviceAuthorizeConfigProvider IntrospectionIssuerProvider - IntrospectionJWTResponseSignerProvider + IntrospectionJWTResponseStrategyProvider + JWTStrategyProvider UseLegacyErrorFormatProvider } diff --git a/generate-mocks.sh b/generate-mocks.sh index 513824d9..5c844046 100755 --- a/generate-mocks.sh +++ b/generate-mocks.sh @@ -2,7 +2,6 @@ ${MOCKGEN:-mockgen} -package mock -destination testing/mock/rw.go net/http ResponseWriter -${MOCKGEN:-mockgen} -package mock -destination testing/mock/hash.go authelia.com/provider/oauth2 Hasher ${MOCKGEN:-mockgen} -package mock -destination testing/mock/introspector.go authelia.com/provider/oauth2 TokenIntrospector ${MOCKGEN:-mockgen} -package mock -destination testing/mock/client.go authelia.com/provider/oauth2 Client ${MOCKGEN:-mockgen} -package mock -destination testing/mock/client_secret.go authelia.com/provider/oauth2 ClientSecret @@ -18,7 +17,7 @@ ${MOCKGEN:-mockgen} -package mock -destination testing/mock/transactional.go aut ${MOCKGEN:-mockgen} -package mock -destination testing/mock/oauth2_storage.go authelia.com/provider/oauth2/handler/oauth2 CoreStorage ${MOCKGEN:-mockgen} -package mock -destination testing/mock/oauth2_device_auth_storage.go -mock_names Storage=MockRFC8628Storage authelia.com/provider/oauth2/handler/rfc8628 Storage ${MOCKGEN:-mockgen} -package mock -destination testing/mock/openid_id_token_storage.go authelia.com/provider/oauth2/handler/openid OpenIDConnectRequestStorage -${MOCKGEN:-mockgen} -package mock -destination testing/mock/pkce_storage.go authelia.com/provider/oauth2/handler/pkce PKCERequestStorage +${MOCKGEN:-mockgen} -package mock -destination testing/mock/pkce_storage.go -mock_names Storage=MockPKCERequestStorage authelia.com/provider/oauth2/handler/pkce Storage ${MOCKGEN:-mockgen} -package mock -destination testing/mock/authorize_code_storage.go authelia.com/provider/oauth2/handler/oauth2 AuthorizeCodeStorage ${MOCKGEN:-mockgen} -package mock -destination testing/mock/oauth2_auth_jwt_storage.go authelia.com/provider/oauth2/handler/rfc7523 RFC7523KeyStorage ${MOCKGEN:-mockgen} -package mock -destination testing/mock/access_token_storage.go authelia.com/provider/oauth2/handler/oauth2 AccessTokenStorage diff --git a/generate.go b/generate.go index cf07ce78..5d33df43 100644 --- a/generate.go +++ b/generate.go @@ -3,7 +3,6 @@ package oauth2 -//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/hash.go authelia.com/provider/oauth2 Hasher //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/storage.go authelia.com/provider/oauth2 Storage //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/transactional.go authelia.com/provider/oauth2/storage Transactional //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/oauth2_storage.go authelia.com/provider/oauth2/handler/oauth2 CoreStorage @@ -20,7 +19,7 @@ package oauth2 //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/refresh_token_strategy.go authelia.com/provider/oauth2/handler/oauth2 ReyfreshTokenStrategy //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/authorize_code_strategy.go authelia.com/provider/oauth2/handler/oauth2 AuthorizeCodeStrategy //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/id_token_strategy.go authelia.com/provider/oauth2/handler/openid OpenIDConnectTokenStrategy -//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/pkce_storage_strategy.go authelia.com/provider/oauth2/handler/pkce PKCERequestStorage +//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/pkce_storage_strategy.go -mock_names Storage=MockPKCERequestStorage authelia.com/provider/oauth2/handler/pkce Storage //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/authorize_handler.go authelia.com/provider/oauth2 AuthorizeEndpointHandler //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/revoke_handler.go authelia.com/provider/oauth2 RevocationHandler //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/token_handler.go authelia.com/provider/oauth2 TokenEndpointHandler diff --git a/go.mod b/go.mod index ff1d6c06..9f51070d 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ toolchain go1.23.1 require ( github.com/dgraph-io/ristretto v0.1.1 github.com/go-jose/go-jose/v4 v4.0.4 - github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 github.com/hashicorp/go-retryablehttp v0.7.7 diff --git a/go.sum b/go.sum index 287c8796..9927a4e4 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,6 @@ github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/go-jose/go-jose/v4 v4.0.4 h1:VsjPI33J0SB9vQM6PLmNjoHqMQNGPiZ0rHL7Ni7Q6/E= github.com/go-jose/go-jose/v4 v4.0.4/go.mod h1:NKb5HO1EZccyMpiZNbdUw/14tiXNyUJh188dfnMCAfc= -github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= -github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY= github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index 4ad81764..5976bbec 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -12,6 +12,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -106,11 +107,11 @@ func (c *AuthorizeExplicitGrantHandler) HandleTokenEndpointRequest(ctx context.C request.SetID(authorizeRequest.GetID()) atLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeAuthorizationCode, oauth2.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) - request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Truncate(jwt.TimePrecision)) rtLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeAuthorizationCode, oauth2.RefreshToken, c.Config.GetRefreshTokenLifespan(ctx)) if rtLifespan > -1 { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Truncate(jwt.TimePrecision)) } return nil diff --git a/handler/oauth2/flow_authorize_code_token_test.go b/handler/oauth2/flow_authorize_code_token_test.go index 4e1810a2..2e583852 100644 --- a/handler/oauth2/flow_authorize_code_token_test.go +++ b/handler/oauth2/flow_authorize_code_token_test.go @@ -19,6 +19,7 @@ import ( "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" "authelia.com/provider/oauth2/testing/mock" + "authelia.com/provider/oauth2/token/jwt" ) func TestAuthorizeCode_PopulateTokenEndpointResponse(t *testing.T) { @@ -462,8 +463,8 @@ func TestAuthorizeExplicitGrantHandler_HandleTokenEndpointRequest(t *testing.T) require.NoError(t, s.InvalidateAuthorizeCodeSession(context.TODO(), sig)) }, func(t *testing.T, s CoreStorage, r *oauth2.AccessRequest, ar *oauth2.AuthorizeRequest) { - assert.Equal(t, time.Now().Add(time.Minute).UTC().Round(time.Second), r.GetSession().GetExpiresAt(oauth2.AccessToken)) - assert.Equal(t, time.Now().Add(time.Minute).UTC().Round(time.Second), r.GetSession().GetExpiresAt(oauth2.RefreshToken)) + assert.Equal(t, time.Now().Add(time.Minute).UTC().Truncate(jwt.TimePrecision), r.GetSession().GetExpiresAt(oauth2.AccessToken)) + assert.Equal(t, time.Now().Add(time.Minute).UTC().Truncate(jwt.TimePrecision), r.GetSession().GetExpiresAt(oauth2.RefreshToken)) }, "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client. The authorization code has already been used.", }, diff --git a/handler/oauth2/flow_authorize_implicit.go b/handler/oauth2/flow_authorize_implicit.go index 00b05ece..2fa9e31d 100644 --- a/handler/oauth2/flow_authorize_implicit.go +++ b/handler/oauth2/flow_authorize_implicit.go @@ -11,6 +11,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -70,7 +71,7 @@ func (c *AuthorizeImplicitGrantTypeHandler) IssueImplicitAccessToken(ctx context // Only override expiry if none is set. atLifespan := oauth2.GetEffectiveLifespan(requester.GetClient(), oauth2.GrantTypeImplicit, oauth2.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) if requester.GetSession().GetExpiresAt(oauth2.AccessToken).IsZero() { - requester.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) + requester.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Truncate(jwt.TimePrecision)) } // Generate the code diff --git a/handler/oauth2/flow_generic_code_token.go b/handler/oauth2/flow_generic_code_token.go index f94d5d00..9c4ac293 100644 --- a/handler/oauth2/flow_generic_code_token.go +++ b/handler/oauth2/flow_generic_code_token.go @@ -8,6 +8,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -130,11 +131,11 @@ func (c *GenericCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context request.SetID(authorizeRequest.GetID()) atLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeAuthorizationCode, oauth2.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) - request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Truncate(jwt.TimePrecision)) rtLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeAuthorizationCode, oauth2.RefreshToken, c.Config.GetRefreshTokenLifespan(ctx)) if rtLifespan > -1 { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Truncate(jwt.TimePrecision)) } return nil diff --git a/handler/oauth2/flow_refresh.go b/handler/oauth2/flow_refresh.go index c0d0a6f5..fc121142 100644 --- a/handler/oauth2/flow_refresh.go +++ b/handler/oauth2/flow_refresh.go @@ -13,6 +13,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -146,11 +147,11 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex } atLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeRefreshToken, oauth2.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) - request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Truncate(jwt.TimePrecision)) rtLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypeRefreshToken, oauth2.RefreshToken, c.Config.GetRefreshTokenLifespan(ctx)) if rtLifespan > -1 { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Truncate(jwt.TimePrecision)) } return nil diff --git a/handler/oauth2/flow_refresh_test.go b/handler/oauth2/flow_refresh_test.go index 230ce7a2..7f8f290e 100644 --- a/handler/oauth2/flow_refresh_test.go +++ b/handler/oauth2/flow_refresh_test.go @@ -20,6 +20,7 @@ import ( "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" "authelia.com/provider/oauth2/testing/mock" + "authelia.com/provider/oauth2/token/jwt" ) func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { @@ -114,7 +115,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", consts.ScopeOffline}, Session: expiredSess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -139,7 +140,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -165,18 +166,18 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, expect: func(t *testing.T) { assert.NotEqual(t, sess, areq.Session) - assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Round(time.Hour), areq.RequestedAt) + assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), areq.RequestedAt) assert.Equal(t, oauth2.Arguments{"foo", consts.ScopeOffline}, areq.GrantedScope) assert.Equal(t, oauth2.Arguments{"foo", consts.ScopeOffline}, areq.RequestedScope) assert.NotEqual(t, url.Values{"foo": []string{"bar"}}, areq.Form) - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) }, }, { @@ -200,7 +201,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", "baz", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -232,7 +233,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", "baz", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -264,7 +265,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "baz", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -294,13 +295,13 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, expect: func(t *testing.T) { assert.NotEqual(t, sess, areq.Session) - assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Round(time.Hour), areq.RequestedAt) + assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), areq.RequestedAt) assert.Equal(t, oauth2.Arguments{"foo", consts.ScopeOffline}, areq.GrantedScope) assert.Equal(t, oauth2.Arguments{"foo", consts.ScopeOffline}, areq.RequestedScope) assert.NotEqual(t, url.Values{"foo": []string{"bar"}}, areq.Form) @@ -328,7 +329,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar"}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, @@ -355,18 +356,18 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar"}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), }) require.NoError(t, err) }, expect: func(t *testing.T) { assert.NotEqual(t, sess, areq.Session) - assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Round(time.Hour), areq.RequestedAt) + assert.NotEqual(t, time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), areq.RequestedAt) assert.Equal(t, oauth2.Arguments{"foo"}, areq.GrantedScope) assert.Equal(t, oauth2.Arguments{"foo"}, areq.RequestedScope) assert.NotEqual(t, url.Values{"foo": []string{"bar"}}, areq.Form) - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) }, }, { @@ -389,7 +390,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { RequestedScope: oauth2.Arguments{"foo", "bar", consts.ScopeOffline}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), + RequestedAt: time.Now().UTC().Add(-time.Hour).Truncate(time.Hour), } err = store.CreateRefreshTokenSession(context.TODO(), sig, req) require.NoError(t, err) diff --git a/handler/oauth2/flow_resource_owner.go b/handler/oauth2/flow_resource_owner.go index 89c59f3f..d140a948 100644 --- a/handler/oauth2/flow_resource_owner.go +++ b/handler/oauth2/flow_resource_owner.go @@ -11,6 +11,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -70,11 +71,11 @@ func (c *ResourceOwnerPasswordCredentialsGrantHandler) HandleTokenEndpointReques delete(request.GetRequestForm(), consts.FormParameterPassword) atLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypePassword, oauth2.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) - request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.AccessToken, time.Now().UTC().Add(atLifespan).Truncate(jwt.TimePrecision)) rtLifespan := oauth2.GetEffectiveLifespan(request.GetClient(), oauth2.GrantTypePassword, oauth2.RefreshToken, c.Config.GetRefreshTokenLifespan(ctx)) if rtLifespan > -1 { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(rtLifespan).Truncate(jwt.TimePrecision)) } return nil diff --git a/handler/oauth2/flow_resource_owner_test.go b/handler/oauth2/flow_resource_owner_test.go index 030e7613..1cabbb16 100644 --- a/handler/oauth2/flow_resource_owner_test.go +++ b/handler/oauth2/flow_resource_owner_test.go @@ -18,6 +18,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/testing/mock" + "authelia.com/provider/oauth2/token/jwt" ) func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) { @@ -89,8 +90,8 @@ func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) { store.EXPECT().Authenticate(context.TODO(), "peter", "pan").Return(nil) }, check: func(areq *oauth2.AccessRequest) { - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) - assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) }, }, } { diff --git a/handler/oauth2/introspector_jwt.go b/handler/oauth2/introspector_jwt.go index e8793723..78be9f64 100644 --- a/handler/oauth2/introspector_jwt.go +++ b/handler/oauth2/introspector_jwt.go @@ -14,55 +14,37 @@ import ( ) type StatelessJWTValidator struct { - jwt.Signer + jwt.Strategy Config interface { oauth2.ScopeStrategyProvider } } -func (v *StatelessJWTValidator) IntrospectToken(ctx context.Context, token string, tokenUse oauth2.TokenUse, accessRequest oauth2.AccessRequester, scopes []string) (oauth2.TokenUse, error) { - t, err := validateJWT(ctx, v.Signer, token) - if err != nil { +func (v *StatelessJWTValidator) IntrospectToken(ctx context.Context, tokenString string, tokenUse oauth2.TokenUse, requester oauth2.AccessRequester, scopes []string) (use oauth2.TokenUse, err error) { + var token *jwt.Token + + if token, err = validateJWT(ctx, v.Strategy, jwt.NewStatelessJWTProfileIntrospectionClient(requester.GetClient()), tokenString); err != nil { return "", err } - if !IsJWTProfileAccessToken(t) { - return "", errorsx.WithStack(oauth2.ErrRequestUnauthorized.WithDebug("The provided token is not a valid RFC9068 JWT Profile Access Token as it is missing the header 'typ' value of 'at+jwt' ")) + if err = token.Valid(jwt.ValidateTypes(consts.JSONWebTokenTypeAccessToken, consts.JSONWebTokenTypeAccessTokenAlternative)); err != nil { + return "", errorsx.WithStack(oauth2.ErrRequestUnauthorized.WithDebug("The provided token is not a valid RFC9068 JWT Profile Access Token as it is missing the header 'typ' value of 'at+jwt'.")) } - requester := AccessTokenJWTToRequest(t) + r := AccessTokenJWTToRequest(token) - if err := matchScopes(v.Config.GetScopeStrategy(ctx), requester.GetGrantedScopes(), scopes); err != nil { + if err = matchScopes(v.Config.GetScopeStrategy(ctx), r.GetGrantedScopes(), scopes); err != nil { return oauth2.AccessToken, err } - accessRequest.Merge(requester) + requester.Merge(r) return oauth2.AccessToken, nil } -// IsJWTProfileAccessToken validates a *jwt.Token is actually a RFC9068 JWT Profile Access Token by checking the -// relevant header as per https://datatracker.ietf.org/doc/html/rfc9068#section-2.1 which explicitly states that -// the header MUST include a typ of 'at+jwt' or 'application/at+jwt' with a preference of 'at+jwt'. -func IsJWTProfileAccessToken(token *jwt.Token) bool { - var ( - raw any - typ string - ok bool - ) - - if raw, ok = token.Header[jwt.JWTHeaderKeyValueType]; !ok { - return false - } - - typ, ok = raw.(string) - - return ok && (typ == jwt.JWTHeaderTypeValueAccessTokenJWT || typ == "application/at+jwt") -} - // AccessTokenJWTToRequest tries to reconstruct oauth2.Request from a JWT. func AccessTokenJWTToRequest(token *jwt.Token) oauth2.Requester { - mapClaims := token.Claims + mapClaims := token.Claims.ToMapClaims() claims := jwt.JWTClaims{} claims.FromMapClaims(mapClaims) diff --git a/handler/oauth2/introspector_jwt_test.go b/handler/oauth2/introspector_jwt_test.go index 33cc49a9..9f9d5d69 100644 --- a/handler/oauth2/introspector_jwt_test.go +++ b/handler/oauth2/introspector_jwt_test.go @@ -6,9 +6,9 @@ package oauth2 import ( "context" "encoding/base64" - "fmt" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,8 +19,6 @@ import ( ) func TestIntrospectJWT(t *testing.T) { - rsaKey := gen.MustRSAKey() - config := &oauth2.Config{ EnforceJWTProfileAccessTokens: true, GlobalSecret: []byte("foofoofoofoofoofoofoofoofoofoofoo"), @@ -28,95 +26,122 @@ func TestIntrospectJWT(t *testing.T) { strategy := &JWTProfileCoreStrategy{ HMACCoreStrategy: NewHMACCoreStrategy(config, "authelia_%s_"), - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return rsaKey, nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(gen.MustRSAKey()), }, Config: config, } - var v = &StatelessJWTValidator{ - Signer: strategy, + validator := &StatelessJWTValidator{ + Strategy: strategy, Config: &oauth2.Config{ ScopeStrategy: oauth2.HierarchicScopeStrategy, }, } - for k, c := range []struct { - description string - token func() string - expectErr error - scopes []string + testCases := []struct { + name string + token func(t *testing.T) string + err error + expected string + scopes []string }{ { - description: "should fail because jwt is expired", - token: func() string { - jwt := jwtExpiredCase(oauth2.AccessToken) - token, _, err := strategy.GenerateAccessToken(context.TODO(), jwt) - assert.NoError(t, err) - return token + name: "ShouldFailTokenExpired", + token: func(t *testing.T) string { + token := jwtExpiredCase(oauth2.AccessToken, time.Unix(1726972738, 0)) + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + require.NoError(t, err) + + return tokenString }, - expectErr: oauth2.ErrTokenExpired, + err: oauth2.ErrTokenExpired, + expected: "Token expired. The token expired. Token expired at 1726969138.", }, { - description: "should pass because scope was granted", - token: func() string { - jwt := jwtValidCase(oauth2.AccessToken) - jwt.GrantedScope = []string{"foo", "bar"} - token, _, err := strategy.GenerateAccessToken(context.TODO(), jwt) - assert.NoError(t, err) - return token + name: "ShouldPassScopeGranted", + token: func(t *testing.T) string { + token := jwtValidCase(oauth2.AccessToken) + token.GrantedScope = []string{"foo", "bar"} + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + + require.NoError(t, err) + + return tokenString }, scopes: []string{"foo"}, }, { - description: "should fail because scope was not granted", - token: func() string { - jwt := jwtValidCase(oauth2.AccessToken) - token, _, err := strategy.GenerateAccessToken(context.TODO(), jwt) - assert.NoError(t, err) - return token + name: "ShouldFailWrongTyp", + token: func(t *testing.T) string { + token := jwtInvalidTypCase(oauth2.AccessToken) + token.GrantedScope = []string{"foo", "bar"} + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + + require.NoError(t, err) + + return tokenString }, - scopes: []string{"foo"}, - expectErr: oauth2.ErrInvalidScope, + scopes: []string{"foo"}, + err: oauth2.ErrRequestUnauthorized, + expected: "The request could not be authorized. Check that you provided valid credentials in the right format. The provided token is not a valid RFC9068 JWT Profile Access Token as it is missing the header 'typ' value of 'at+jwt'.", }, { - description: "should fail because signature is invalid", - token: func() string { - jwt := jwtValidCase(oauth2.AccessToken) - token, _, err := strategy.GenerateAccessToken(context.TODO(), jwt) - assert.NoError(t, err) - parts := strings.Split(token, ".") - require.Len(t, parts, 3, "%s - %v", token, parts) + name: "ShouldFailScopeNotGranted", + token: func(t *testing.T) string { + token := jwtValidCase(oauth2.AccessToken) + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + require.NoError(t, err) + + return tokenString + }, + scopes: []string{"foo"}, + err: oauth2.ErrInvalidScope, + expected: "The requested scope is invalid, unknown, or malformed. The request scope 'foo' has not been granted or is not allowed to be requested.", + }, + { + name: "ShouldFailInvalidSignature", + token: func(t *testing.T) string { + token := jwtValidCase(oauth2.AccessToken) + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + require.NoError(t, err) + parts := strings.Split(tokenString, ".") + require.Len(t, parts, 3, "%s - %v", tokenString, parts) dec, err := base64.RawURLEncoding.DecodeString(parts[1]) - assert.NoError(t, err) + require.NoError(t, err) s := strings.ReplaceAll(string(dec), "peter", "piper") parts[1] = base64.RawURLEncoding.EncodeToString([]byte(s)) + return strings.Join(parts, ".") }, - expectErr: oauth2.ErrTokenSignatureMismatch, + err: oauth2.ErrTokenSignatureMismatch, + expected: "Token signature mismatch. Check that you provided a valid token in the right format. Token has an invalid signature.", }, { - description: "should pass", - token: func() string { - jwt := jwtValidCase(oauth2.AccessToken) - token, _, err := strategy.GenerateAccessToken(context.TODO(), jwt) - assert.NoError(t, err) - return token + name: "ShouldPass", + token: func(t *testing.T) string { + token := jwtValidCase(oauth2.AccessToken) + tokenString, _, err := strategy.GenerateAccessToken(context.TODO(), token) + require.NoError(t, err) + + return tokenString }, }, - } { - t.Run(fmt.Sprintf("case=%d:%v", k, c.description), func(t *testing.T) { - if c.scopes == nil { - c.scopes = []string{} + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.scopes == nil { + tc.scopes = []string{} } areq := oauth2.NewAccessRequest(nil) - _, err := v.IntrospectToken(context.TODO(), c.token(), oauth2.AccessToken, areq, c.scopes) + _, err := validator.IntrospectToken(context.TODO(), tc.token(t), oauth2.AccessToken, areq, tc.scopes) - if c.expectErr != nil { - require.EqualError(t, err, c.expectErr.Error()) + if tc.err != nil { + assert.EqualError(t, err, tc.err.Error()) + assert.EqualError(t, oauth2.ErrorToDebugRFC6749Error(err), tc.expected) } else { require.NoError(t, err) assert.Equal(t, "peter", areq.Session.GetSubject()) @@ -126,16 +151,18 @@ func TestIntrospectJWT(t *testing.T) { } func BenchmarkIntrospectJWT(b *testing.B) { + config := &oauth2.Config{} + strategy := &JWTProfileCoreStrategy{ - Signer: &jwt.DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(gen.MustRSAKey()), }, - Config: &oauth2.Config{}, + Config: config, } v := &StatelessJWTValidator{ - Signer: strategy, + Strategy: strategy, } jwt := jwtValidCase(oauth2.AccessToken) diff --git a/handler/oauth2/strategy.go b/handler/oauth2/strategy.go index 065539b6..4f165a28 100644 --- a/handler/oauth2/strategy.go +++ b/handler/oauth2/strategy.go @@ -12,13 +12,13 @@ import ( // NewCoreStrategy is a special constructor that if provided a signer will automatically decorate the HMACCoreStrategy // with a JWTProfileCoreStrategy, otherwise it just returns the HMACCoreStrategy. -func NewCoreStrategy(config CoreStrategyConfigurator, prefix string, signer jwt.Signer) (strategy CoreStrategy) { - if signer == nil { +func NewCoreStrategy(config CoreStrategyConfigurator, prefix string, strategy jwt.Strategy) (core CoreStrategy) { + if strategy == nil { return NewHMACCoreStrategy(config, prefix) } return &JWTProfileCoreStrategy{ - Signer: signer, + Strategy: strategy, HMACCoreStrategy: NewHMACCoreStrategy(config, prefix), Config: config, } diff --git a/handler/oauth2/strategy_jwt_profile.go b/handler/oauth2/strategy_jwt_profile.go index d8589563..c553d5a5 100644 --- a/handler/oauth2/strategy_jwt_profile.go +++ b/handler/oauth2/strategy_jwt_profile.go @@ -5,6 +5,7 @@ package oauth2 import ( "context" + "fmt" "strings" "time" @@ -13,12 +14,12 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/token/jwt" - "authelia.com/provider/oauth2/x/errorsx" ) // JWTProfileCoreStrategy is a JWT RS256 strategy. type JWTProfileCoreStrategy struct { - jwt.Signer + jwt.Strategy + HMACCoreStrategy *HMACCoreStrategy Config interface { oauth2.AccessTokenIssuerProvider @@ -56,7 +57,7 @@ func (s *JWTProfileCoreStrategy) GenerateAccessToken(ctx context.Context, reques func (s *JWTProfileCoreStrategy) ValidateAccessToken(ctx context.Context, requester oauth2.Requester, tokenString string) (err error) { if possible, _ := s.IsPossiblyJWTProfileAccessToken(ctx, tokenString); possible { - _, err = validateJWT(ctx, s.Signer, tokenString) + _, err = validateJWT(ctx, s.Strategy, jwt.NewJWTProfileAccessTokenClient(requester.GetClient()), tokenString) return } @@ -170,42 +171,107 @@ func (s *JWTProfileCoreStrategy) GenerateJWT(ctx context.Context, tokenType oaut s.Config.GetJWTScopeField(ctx), ) - return s.Signer.Generate(ctx, claims.ToMapClaims(), header) + mapClaims := claims.ToMapClaims() + + return s.Strategy.Encode(ctx, mapClaims, jwt.WithHeaders(header), jwt.WithJWTProfileAccessTokenClient(client)) } -func validateJWT(ctx context.Context, jwtStrategy jwt.Signer, token string) (t *jwt.Token, err error) { - t, err = jwtStrategy.Decode(ctx, token) - if err == nil { - err = t.Claims.Valid() - return +func validateJWT(ctx context.Context, strategy jwt.Strategy, client jwt.Client, tokenString string) (token *jwt.Token, err error) { + if token, err = strategy.Decode(ctx, tokenString, jwt.WithClient(client)); err != nil { + return nil, fmtValidateJWTError(token, client, err) } - var e *jwt.ValidationError - if err != nil && errors.As(err, &e) { - err = errorsx.WithStack(toRFCErr(e).WithWrap(err).WithDebugError(err)) + if err = token.Claims.Valid(); err != nil { + return token, fmtValidateJWTError(token, client, err) + } + + return token, nil +} + +func fmtValidateJWTError(token *jwt.Token, client jwt.Client, inner error) (err error) { + var ( + clientText string + sigKID, sigAlg string + encKID, encAlg, enc string + date *jwt.NumericDate + ) + + if client != nil { + clientText = fmt.Sprintf("provided by client with id '%s' ", client.GetID()) + sigKID, sigAlg = client.GetSigningKeyID(), client.GetSigningAlg() + encKID, encAlg, enc = client.GetEncryptionKeyID(), client.GetEncryptionAlg(), client.GetEncryptionEnc() } - return -} - -func toRFCErr(v *jwt.ValidationError) *oauth2.RFC6749Error { - switch { - case v == nil: - return nil - case v.Has(jwt.ValidationErrorMalformed): - return oauth2.ErrInvalidTokenFormat - case v.Has(jwt.ValidationErrorUnverifiable | jwt.ValidationErrorSignatureInvalid): - return oauth2.ErrTokenSignatureMismatch - case v.Has(jwt.ValidationErrorExpired): - return oauth2.ErrTokenExpired - case v.Has(jwt.ValidationErrorAudience | - jwt.ValidationErrorIssuedAt | - jwt.ValidationErrorIssuer | - jwt.ValidationErrorNotValidYet | - jwt.ValidationErrorId | - jwt.ValidationErrorClaimsInvalid): - return oauth2.ErrTokenClaim - default: - return oauth2.ErrRequestUnauthorized + if errJWTValidation := new(jwt.ValidationError); errors.As(inner, &errJWTValidation) { + switch { + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyIDInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'kid' header value '%s' but it was signed with the 'kid' header value '%s'.", clientText, sigKID, token.KeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderAlgorithmInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'alg' header value '%s' but it was signed with the 'alg' header value '%s'.", clientText, sigAlg, token.SignatureAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderTypeInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be signed with the 'typ' header value '%s' but it was signed with the 'typ' header value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.Header[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionTypeInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'typ' header value '%s' but it was encrypted with the 'typ' header value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalidMismatch): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with a 'cty' header value and signed with a 'typ' value that match but it was encrypted with the 'cty' header value '%s' and signed with the 'typ' header value '%s'.", clientText, token.HeaderJWE[consts.JSONWebTokenHeaderContentType], token.HeaderJWE[consts.JSONWebTokenHeaderType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentTypeInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'cty' header value '%s' but it was encrypted with the 'cty' header value '%s'.", clientText, consts.JSONWebTokenTypeJWT, token.HeaderJWE[consts.JSONWebTokenHeaderContentType]) + case errJWTValidation.Has(jwt.ValidationErrorHeaderEncryptionKeyIDInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'kid' header value '%s' but it was encrypted with the 'kid' header value '%s'.", clientText, encKID, token.EncryptionKeyID) + case errJWTValidation.Has(jwt.ValidationErrorHeaderKeyAlgorithmInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'alg' header value '%s' but it was encrypted with the 'alg' header value '%s'.", clientText, encAlg, token.KeyAlgorithm) + case errJWTValidation.Has(jwt.ValidationErrorHeaderContentEncryptionInvalid): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis expected to be encrypted with the 'enc' header value '%s' but it was encrypted with the 'enc' header value '%s'.", clientText, enc, token.ContentEncryption) + case errJWTValidation.Has(jwt.ValidationErrorMalformedNotCompactSerialized): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis malformed. The token does not appear to be a JWE or JWS compact serialized JWT.", clientText) + case errJWTValidation.Has(jwt.ValidationErrorMalformed): + return oauth2.ErrInvalidTokenFormat.WithDebugf("Token %sis malformed. %s.", clientText, strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorUnverifiable): + return oauth2.ErrTokenSignatureMismatch.WithDebugf("Token %sis not able to be verified. %s.", clientText, strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + case errJWTValidation.Has(jwt.ValidationErrorSignatureInvalid): + return oauth2.ErrTokenSignatureMismatch.WithDebugf("Token %shas an invalid signature.", clientText) + case errJWTValidation.Has(jwt.ValidationErrorExpired): + if date, err = token.Claims.GetExpirationTime(); err == nil { + return oauth2.ErrTokenExpired.WithDebugf("Token %sexpired at %d.", clientText, date.Int64()) + } else { + return oauth2.ErrTokenExpired.WithDebugf("Token %sdoes not have an 'exp' claim or it has an invalid type.", clientText) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuedAt): + if date, err = token.Claims.GetIssuedAt(); err == nil { + return oauth2.ErrTokenClaim.WithDebugf("Token %sis issued in the future. The token was issued at %d.", clientText, date.Int64()) + } else { + return oauth2.ErrTokenClaim.WithDebugf("Token %sis issued in the future. The token does not have an 'iat' claim or it has an invalid type.", clientText) + } + case errJWTValidation.Has(jwt.ValidationErrorNotValidYet): + if date, err = token.Claims.GetNotBefore(); err == nil { + return oauth2.ErrTokenClaim.WithDebugf("Token %sis not valid yet. The token is not valid before %d.", clientText, date.Int64()) + } else { + return oauth2.ErrTokenClaim.WithDebugf("Token %sis not valid yet. The token does not have an 'nbf' claim or it has an invalid type.", clientText) + } + case errJWTValidation.Has(jwt.ValidationErrorIssuer): + var iss string + + if iss, err = token.Claims.GetIssuer(); err == nil { + return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid issuer. The token was expected to have an 'iss' claim with one of the following values: ''. The 'iss' claim has a value of '%s'.", clientText, iss) + } else { + return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid issuer. The token does not have an 'iss' claim or it has an invalid type.", clientText) + } + case errJWTValidation.Has(jwt.ValidationErrorAudience): + var aud jwt.ClaimStrings + + if aud, err = token.Claims.GetAudience(); err == nil { + return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid audience. The token was expected to have an 'iss' claim with one of the following values: ''. The 'aud' claim has a value of '%s'.", clientText, aud) + } else { + return oauth2.ErrTokenClaim.WithDebugf("Token %shas an invalid audience. The token does not have an 'aud' claim or it has an invalid type.", clientText) + } + case errJWTValidation.Has(jwt.ValidationErrorClaimsInvalid): + return oauth2.ErrTokenClaim.WithDebugf("Token %shas invalid claims. Error occurred trying to validate the request objects claims: %s", clientText, strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + default: + return oauth2.ErrTokenClaim.WithDebugf("Token %scould not be validated. Error occurred trying to validate the token: %s", clientText, strings.TrimPrefix(errJWTValidation.Error(), "go-jose/go-jose: ")) + } + } else if errJWKLookup := new(jwt.JWKLookupError); errors.As(inner, &errJWKLookup) { + return oauth2.ErrRequestUnauthorized.WithDebugf("Token %scould not be validated due to a key lookup error. %s.", clientText, errJWKLookup.Description) + } else { + return oauth2.ErrRequestUnauthorized.WithDebugf("Token %scould not be validated. %s", clientText, oauth2.ErrorToDebugRFC6749Error(inner).Error()) } } diff --git a/handler/oauth2/strategy_jwt_profile_test.go b/handler/oauth2/strategy_jwt_profile_test.go index c83861ac..ab820c26 100644 --- a/handler/oauth2/strategy_jwt_profile_test.go +++ b/handler/oauth2/strategy_jwt_profile_test.go @@ -23,8 +23,8 @@ import ( var rsaKey = gen.MustRSAKey() -// returns a valid JWT type. The JWTClaims.ExpiresAt time is intentionally -// left empty to ensure it is pulled from the session's ExpiresAt map for +// returns a valid JWT type. The JWTClaims.ExpirationTime time is intentionally +// left empty to ensure it is pulled from the session's ExpirationTime map for // the given oauth2.TokenType. var jwtValidCase = func(tokenType oauth2.TokenType) *oauth2.Request { r := &oauth2.Request{ @@ -55,6 +55,35 @@ var jwtValidCase = func(tokenType oauth2.TokenType) *oauth2.Request { return r } +var jwtInvalidTypCase = func(tokenType oauth2.TokenType) *oauth2.Request { + r := &oauth2.Request{ + Client: &oauth2.DefaultClient{ + ClientSecret: mustNewBCryptClientSecretPlain("foobarfoobarfoobarfoobar"), + }, + Session: &JWTSession{ + JWTClaims: &jwt.JWTClaims{ + Issuer: "oauth2", + Subject: "peter", + IssuedAt: time.Now().UTC(), + NotBefore: time.Now().UTC(), + Extra: map[string]any{"foo": "bar"}, + }, + JWTHeader: &jwt.Headers{ + Extra: map[string]any{consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeJWT}, + }, + ExpiresAt: map[oauth2.TokenType]time.Time{ + tokenType: time.Now().UTC().Add(time.Hour), + }, + }, + } + r.SetRequestedScopes([]string{consts.ScopeEmail, consts.ScopeOffline}) + r.GrantScope(consts.ScopeEmail) + r.GrantScope(consts.ScopeOffline) + r.SetRequestedAudience([]string{"group0"}) + r.GrantAudience("group0") + return r +} + var jwtValidCaseWithZeroRefreshExpiry = func(tokenType oauth2.TokenType) *oauth2.Request { r := &oauth2.Request{ Client: &oauth2.DefaultClient{ @@ -103,7 +132,7 @@ var jwtValidCaseWithRefreshExpiry = func(tokenType oauth2.TokenType) *oauth2.Req }, ExpiresAt: map[oauth2.TokenType]time.Time{ tokenType: time.Now().UTC().Add(time.Hour), - oauth2.RefreshToken: time.Now().UTC().Add(time.Hour * 2).Round(time.Hour), + oauth2.RefreshToken: time.Now().UTC().Add(time.Hour * 2).Truncate(time.Hour), }, }, } @@ -115,10 +144,10 @@ var jwtValidCaseWithRefreshExpiry = func(tokenType oauth2.TokenType) *oauth2.Req return r } -// returns an expired JWT type. The JWTClaims.ExpiresAt time is intentionally -// left empty to ensure it is pulled from the session's ExpiresAt map for +// returns an expired JWT type. The JWTClaims.ExpirationTime time is intentionally +// left empty to ensure it is pulled from the session's ExpirationTime map for // the given oauth2.TokenType. -var jwtExpiredCase = func(tokenType oauth2.TokenType) *oauth2.Request { +var jwtExpiredCase = func(tokenType oauth2.TokenType, now time.Time) *oauth2.Request { r := &oauth2.Request{ Client: &oauth2.DefaultClient{ ClientSecret: mustNewBCryptClientSecretPlain("foobarfoobarfoobarfoobar"), @@ -127,16 +156,16 @@ var jwtExpiredCase = func(tokenType oauth2.TokenType) *oauth2.Request { JWTClaims: &jwt.JWTClaims{ Issuer: "oauth2", Subject: "peter", - IssuedAt: time.Now().UTC().Add(-time.Minute * 10), - NotBefore: time.Now().UTC().Add(-time.Minute * 10), - ExpiresAt: time.Now().UTC().Add(-time.Minute), + IssuedAt: now.UTC().Add(-time.Minute * 10), + NotBefore: now.UTC().Add(-time.Minute * 10), + ExpiresAt: now.UTC().Add(-time.Minute), Extra: map[string]any{"foo": "bar"}, }, JWTHeader: &jwt.Headers{ Extra: make(map[string]any), }, ExpiresAt: map[oauth2.TokenType]time.Time{ - tokenType: time.Now().UTC().Add(-time.Hour), + tokenType: now.UTC().Add(-time.Hour), }, }, } @@ -163,7 +192,7 @@ func TestAccessToken(t *testing.T) { pass: true, }, { - r: jwtExpiredCase(oauth2.AccessToken), + r: jwtExpiredCase(oauth2.AccessToken, time.Unix(1726972738, 0)), pass: false, }, { @@ -176,19 +205,18 @@ func TestAccessToken(t *testing.T) { }, } { t.Run(fmt.Sprintf("case=%d/%d", s, k), func(t *testing.T) { - signer := &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return rsaKey, nil - }, - } - config := &oauth2.Config{ EnforceJWTProfileAccessTokens: true, GlobalSecret: []byte("foofoofoofoofoofoofoofoofoofoofoo"), JWTScopeClaimKey: scopeField, } - strategy := NewCoreStrategy(config, "authelia_%s_", signer) + jwtStrategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(rsaKey), + } + + strategy := NewCoreStrategy(config, "authelia_%s_", jwtStrategy) token, signature, err := strategy.GenerateAccessToken(context.TODO(), c.r) assert.NoError(t, err) @@ -221,7 +249,7 @@ func TestAccessToken(t *testing.T) { require.NoError(t, json.Unmarshal(rawHeader, &header)) - assert.Equal(t, jwt.JWTHeaderTypeValueAccessTokenJWT, header[jwt.JWTHeaderKeyValueType]) + assert.Equal(t, consts.JSONWebTokenTypeAccessToken, header[consts.JSONWebTokenHeaderType]) extraClaimsSession, ok := c.r.GetSession().(oauth2.ExtraClaimsSession) require.True(t, ok) diff --git a/handler/oauth2/strategy_jwt_session.go b/handler/oauth2/strategy_jwt_session.go index 54ad57d0..9ec9e155 100644 --- a/handler/oauth2/strategy_jwt_session.go +++ b/handler/oauth2/strategy_jwt_session.go @@ -9,6 +9,7 @@ import ( "github.com/mohae/deepcopy" "authelia.com/provider/oauth2" + "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/token/jwt" ) @@ -42,11 +43,11 @@ func (j *JWTSession) GetJWTHeader() *jwt.Headers { if j.JWTHeader == nil { j.JWTHeader = &jwt.Headers{ Extra: map[string]any{ - jwt.JWTHeaderKeyValueType: jwt.JWTHeaderTypeValueAccessTokenJWT, + consts.JSONWebTokenHeaderType: consts.JSONWebTokenTypeAccessToken, }, } - } else if j.JWTHeader.Extra[jwt.JWTHeaderKeyValueType] == nil { - j.JWTHeader.Extra[jwt.JWTHeaderKeyValueType] = jwt.JWTHeaderTypeValueAccessTokenJWT + } else if j.JWTHeader.Extra[consts.JSONWebTokenHeaderType] == nil { + j.JWTHeader.Extra[consts.JSONWebTokenHeaderType] = consts.JSONWebTokenTypeAccessToken } return j.JWTHeader diff --git a/handler/openid/flow_device_authorization_test.go b/handler/openid/flow_device_authorization_test.go index 3c5ce5ff..b23309b3 100644 --- a/handler/openid/flow_device_authorization_test.go +++ b/handler/openid/flow_device_authorization_test.go @@ -27,13 +27,15 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpoin AuthorizeCodeLifespan: time.Minute * 24, RFC8628CodeLifespan: time.Minute * 24, } - j := &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, - }, + + jwtStrategy := &jwt.DefaultStrategy{ Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + + j := &DefaultStrategy{ + Strategy: jwtStrategy, + Config: config, } oidcStore := mock.NewMockOpenIDConnectRequestStorage(ctrl) @@ -41,7 +43,7 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpoin handler := &OpenIDConnectDeviceAuthorizeHandler{ OpenIDConnectRequestStorage: oidcStore, - OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Signer, config), + OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Strategy, config), CodeTokenEndpointHandler: tokenHandler, Config: config, IDTokenHandleHelper: &IDTokenHandleHelper{ @@ -128,13 +130,15 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateTokenEndpointResponse(t *te AuthorizeCodeLifespan: time.Minute * 24, RFC8628CodeLifespan: time.Minute * 24, } - j := &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, - }, + + jwtStrategy := &jwt.DefaultStrategy{ Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + + j := &DefaultStrategy{ + Strategy: jwtStrategy, + Config: config, } oidcStore := mock.NewMockOpenIDConnectRequestStorage(ctrl) @@ -142,7 +146,7 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateTokenEndpointResponse(t *te handler := &OpenIDConnectDeviceAuthorizeHandler{ OpenIDConnectRequestStorage: oidcStore, - OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Signer, config), + OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Strategy, config), CodeTokenEndpointHandler: tokenHandler, Config: config, IDTokenHandleHelper: &IDTokenHandleHelper{ @@ -165,7 +169,7 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateTokenEndpointResponse(t *te setup: func() { sess := &DefaultSession{ Claims: &jwt.IDTokenClaims{ - RequestedAt: time.Now().UTC(), + RequestedAt: jwt.Now(), Subject: "foobar", }, Headers: &jwt.Headers{}, @@ -293,7 +297,7 @@ func TestOpenIDConnectDeviceAuthorizeHandler_PopulateTokenEndpointResponse(t *te setup: func() { sess := &DefaultSession{ Claims: &jwt.IDTokenClaims{ - RequestedAt: time.Now().UTC(), + RequestedAt: jwt.Now(), Subject: "", }, Headers: &jwt.Headers{}, diff --git a/handler/openid/flow_explicit_auth_test.go b/handler/openid/flow_explicit_auth_test.go index d007dd47..4e0d3b1a 100644 --- a/handler/openid/flow_explicit_auth_test.go +++ b/handler/openid/flow_explicit_auth_test.go @@ -120,13 +120,14 @@ func makeOpenIDConnectExplicitHandler(ctrl *gomock.Controller, minParameterEntro store := mock.NewMockOpenIDConnectRequestStorage(ctrl) config := &oauth2.Config{MinParameterEntropy: minParameterEntropy} - var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, - }, + jwtStrategy := &jwt.DefaultStrategy{ Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + + var j = &DefaultStrategy{ + Strategy: jwtStrategy, + Config: config, } return OpenIDConnectExplicitHandler{ @@ -134,7 +135,7 @@ func makeOpenIDConnectExplicitHandler(ctrl *gomock.Controller, minParameterEntro IDTokenHandleHelper: &IDTokenHandleHelper{ IDTokenStrategy: j, }, - OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Signer, config), + OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Strategy, config), Config: config, }, store } diff --git a/handler/openid/flow_explicit_token_test.go b/handler/openid/flow_explicit_token_test.go index 1af76df9..3993fb42 100644 --- a/handler/openid/flow_explicit_token_test.go +++ b/handler/openid/flow_explicit_token_test.go @@ -113,7 +113,7 @@ func TestExplicit_PopulateTokenEndpointResponse(t *testing.T) { return key.PublicKey, nil }) require.NoError(t, err) - claims := decodedIdToken.Claims + claims := decodedIdToken.Claims.ToMapClaims() assert.NotEmpty(t, claims["at_hash"]) idTokenExp := internal.ExtractJwtExpClaim(t, idToken) internal.RequireEqualTime(t, time.Now().Add(*internal.TestLifespans.AuthorizationCodeGrantIDTokenLifespan).UTC(), *idTokenExp, time.Minute) @@ -144,7 +144,7 @@ func TestExplicit_PopulateTokenEndpointResponse(t *testing.T) { return key.PublicKey, nil }) require.NoError(t, err) - claims := decodedIdToken.Claims + claims := decodedIdToken.Claims.ToMapClaims() assert.NotEmpty(t, claims["at_hash"]) idTokenExp := internal.ExtractJwtExpClaim(t, idToken) internal.RequireEqualTime(t, time.Now().Add(time.Hour), *idTokenExp, time.Minute) @@ -211,15 +211,18 @@ func TestExplicit_PopulateTokenEndpointResponse(t *testing.T) { aresp := oauth2.NewAccessResponse() areq := oauth2.NewAccessRequest(session) + config := &oauth2.Config{ + MinParameterEntropy: oauth2.MinParameterEntropy, + } + + jwtStrategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, - }, - Config: &oauth2.Config{ - MinParameterEntropy: oauth2.MinParameterEntropy, - }, + Strategy: jwtStrategy, + Config: config, } h := &OpenIDConnectExplicitHandler{ diff --git a/handler/openid/flow_hybrid.go b/handler/openid/flow_hybrid.go index 4d7dd0c8..4c9915bc 100644 --- a/handler/openid/flow_hybrid.go +++ b/handler/openid/flow_hybrid.go @@ -21,7 +21,7 @@ type OpenIDConnectHybridHandler struct { OpenIDConnectRequestValidator *OpenIDConnectRequestValidator OpenIDConnectRequestStorage OpenIDConnectRequestStorage - Enigma *jwt.DefaultSigner + Enigma *jwt.DefaultStrategy Config interface { oauth2.IDTokenLifespanProvider @@ -132,11 +132,11 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. // sets the proper access/refresh token lifetimes. // // if c.AuthorizeExplicitGrantHandler.RefreshTokenLifespan > -1 { - // requester.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.AuthorizeExplicitGrantHandler.RefreshTokenLifespan).Round(time.Second)) + // requester.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.AuthorizeExplicitGrantHandler.RefreshTokenLifespan).Truncate(jwt.TimePrecision)) // } // This is required because we must limit the authorize code lifespan. - requester.GetSession().SetExpiresAt(oauth2.AuthorizeCode, time.Now().UTC().Add(c.AuthorizeExplicitGrantHandler.Config.GetAuthorizeCodeLifespan(ctx)).Round(time.Second)) + requester.GetSession().SetExpiresAt(oauth2.AuthorizeCode, time.Now().UTC().Add(c.AuthorizeExplicitGrantHandler.Config.GetAuthorizeCodeLifespan(ctx)).Truncate(jwt.TimePrecision)) if err = c.AuthorizeExplicitGrantHandler.CoreStorage.CreateAuthorizeCodeSession(ctx, signature, requester.Sanitize(c.AuthorizeExplicitGrantHandler.GetSanitationWhiteList(ctx))); err != nil { return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebugError(err)) diff --git a/handler/openid/flow_hybrid_test.go b/handler/openid/flow_hybrid_test.go index cef1aa76..6413e4b2 100644 --- a/handler/openid/flow_hybrid_test.go +++ b/handler/openid/flow_hybrid_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - xjwt "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -18,7 +17,6 @@ import ( hoauth2 "authelia.com/provider/oauth2/handler/oauth2" "authelia.com/provider/oauth2/internal" "authelia.com/provider/oauth2/internal/consts" - "authelia.com/provider/oauth2/internal/gen" "authelia.com/provider/oauth2/storage" "authelia.com/provider/oauth2/token/hmac" "authelia.com/provider/oauth2/token/jwt" @@ -191,11 +189,9 @@ func TestHybrid_HandleAuthorizeEndpointRequest(t *testing.T) { assert.NotEmpty(t, idToken) assert.True(t, request.GetSession().GetExpiresAt(oauth2.IDToken).IsZero()) - parser := xjwt.NewParser() + claims := &jwt.IDTokenClaims{} - claims := &IDTokenClaims{} - - _, _, err := parser.ParseUnverified(idToken, claims) + _, err := jwt.UnsafeParseSignedAny(idToken, claims) require.NoError(t, err) assert.Equal(t, "MvmJNOT-fq6rnnnrUTC_2A", claims.StateHash) @@ -338,14 +334,12 @@ func TestHybrid_HandleAuthorizeEndpointRequest(t *testing.T) { assert.NotEmpty(t, idToken) assert.True(t, request.GetSession().GetExpiresAt(oauth2.IDToken).IsZero()) - parser := xjwt.NewParser() - - claims := &IDTokenClaims{} - - _, _, err := parser.ParseUnverified(idToken, claims) + claims := &jwt.IDTokenClaims{} + _, err := jwt.UnsafeParseSignedAny(idToken, claims) require.NoError(t, err) - internal.RequireEqualTime(t, time.Now().Add(*internal.TestLifespans.ImplicitGrantIDTokenLifespan), claims.ExpiresAt.Time, time.Minute) + + internal.RequireEqualTime(t, time.Now().Add(*internal.TestLifespans.ImplicitGrantIDTokenLifespan), claims.GetExpirationTimeSafe(), time.Minute) assert.NotEmpty(t, claims.CodeHash) assert.Empty(t, claims.StateHash) @@ -416,13 +410,6 @@ func TestHybrid_HandleAuthorizeEndpointRequest(t *testing.T) { } } -type IDTokenClaims struct { - StateHash string `json:"s_hash"` - CodeHash string `json:"c_hash"` - - xjwt.RegisteredClaims -} - var hmacStrategy = &hoauth2.HMACCoreStrategy{ Enigma: &hmac.HMACStrategy{ Config: &oauth2.Config{ @@ -432,35 +419,34 @@ var hmacStrategy = &hoauth2.HMACCoreStrategy{ } func makeOpenIDConnectHybridHandler(minParameterEntropy int) OpenIDConnectHybridHandler { + config := &oauth2.Config{ + ScopeStrategy: oauth2.HierarchicScopeStrategy, + MinParameterEntropy: minParameterEntropy, + AccessTokenLifespan: time.Hour, + AuthorizeCodeLifespan: time.Hour, + RefreshTokenLifespan: time.Hour, + } + + jwtStrategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + var idStrategy = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil - }, - }, - Config: &oauth2.Config{ - MinParameterEntropy: minParameterEntropy, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.MustGenDefaultIssuer(), }, + Config: config, } var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, - }, + Strategy: jwtStrategy, Config: &oauth2.Config{ MinParameterEntropy: minParameterEntropy, }, } - config := &oauth2.Config{ - ScopeStrategy: oauth2.HierarchicScopeStrategy, - MinParameterEntropy: minParameterEntropy, - AccessTokenLifespan: time.Hour, - AuthorizeCodeLifespan: time.Hour, - RefreshTokenLifespan: time.Hour, - } return OpenIDConnectHybridHandler{ AuthorizeExplicitGrantHandler: &hoauth2.AuthorizeExplicitGrantHandler{ AuthorizeCodeStrategy: hmacStrategy, @@ -479,7 +465,7 @@ func makeOpenIDConnectHybridHandler(minParameterEntropy int) OpenIDConnectHybrid IDTokenStrategy: idStrategy, }, Config: config, - OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Signer, config), + OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Strategy, config), OpenIDConnectRequestStorage: storage.NewMemoryStore(), } } diff --git a/handler/openid/flow_implicit.go b/handler/openid/flow_implicit.go index 7fe3c47a..e46006f5 100644 --- a/handler/openid/flow_implicit.go +++ b/handler/openid/flow_implicit.go @@ -18,7 +18,7 @@ type OpenIDConnectImplicitHandler struct { AuthorizeImplicitGrantTypeHandler *hoauth2.AuthorizeImplicitGrantTypeHandler OpenIDConnectRequestValidator *OpenIDConnectRequestValidator - RS256JWTStrategy *jwt.DefaultSigner + RS256JWTStrategy *jwt.DefaultStrategy Config interface { oauth2.IDTokenLifespanProvider diff --git a/handler/openid/flow_implicit_test.go b/handler/openid/flow_implicit_test.go index 81350df9..bfd4e80b 100644 --- a/handler/openid/flow_implicit_test.go +++ b/handler/openid/flow_implicit_test.go @@ -17,7 +17,6 @@ import ( hoauth2 "authelia.com/provider/oauth2/handler/oauth2" "authelia.com/provider/oauth2/internal" "authelia.com/provider/oauth2/internal/consts" - "authelia.com/provider/oauth2/internal/gen" "authelia.com/provider/oauth2/storage" "authelia.com/provider/oauth2/token/jwt" ) @@ -30,19 +29,17 @@ func makeOpenIDConnectImplicitHandler(minParameterEntropy int) OpenIDConnectImpl } var idStrategy = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return gen.MustRSAKey(), nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), }, Config: config, } var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), }, Config: config, } @@ -56,7 +53,7 @@ func makeOpenIDConnectImplicitHandler(minParameterEntropy int) OpenIDConnectImpl IDTokenHandleHelper: &IDTokenHandleHelper{ IDTokenStrategy: idStrategy, }, - OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Signer, config), + OpenIDConnectRequestValidator: NewOpenIDConnectRequestValidator(j.Strategy, config), Config: config, } } diff --git a/handler/openid/flow_refresh_token.go b/handler/openid/flow_refresh_token.go index 80563849..97699187 100644 --- a/handler/openid/flow_refresh_token.go +++ b/handler/openid/flow_refresh_token.go @@ -12,6 +12,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -49,7 +50,7 @@ func (c *OpenIDConnectRefreshHandler) HandleTokenEndpointRequest(ctx context.Con } // We need to reset the expires at value as this would be the previous expiry. - sess.IDTokenClaims().ExpiresAt = time.Time{} + sess.IDTokenClaims().ExpirationTime = jwt.NewNumericDate(time.Time{}) // These will be recomputed in PopulateTokenEndpointResponse sess.IDTokenClaims().JTI = "" @@ -92,7 +93,7 @@ func (c *OpenIDConnectRefreshHandler) PopulateTokenEndpointResponse(ctx context. claims.AccessTokenHash = c.GetAccessTokenHash(ctx, requester, responder) claims.JTI = uuid.New().String() claims.CodeHash = "" - claims.IssuedAt = time.Now().Truncate(time.Second) + claims.IssuedAt = jwt.Now() idTokenLifespan := oauth2.GetEffectiveLifespan(requester.GetClient(), oauth2.GrantTypeRefreshToken, oauth2.IDToken, c.Config.GetIDTokenLifespan(ctx)) return c.IssueExplicitIDToken(ctx, idTokenLifespan, requester, responder) diff --git a/handler/openid/flow_refresh_token_test.go b/handler/openid/flow_refresh_token_test.go index 333bfbdb..77e0bc0b 100644 --- a/handler/openid/flow_refresh_token_test.go +++ b/handler/openid/flow_refresh_token_test.go @@ -82,11 +82,12 @@ func TestOpenIDConnectRefreshHandler_HandleTokenEndpointRequest(t *testing.T) { } func TestOpenIDConnectRefreshHandler_PopulateTokenEndpointResponse(t *testing.T) { + config := &oauth2.Config{} + var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return key, nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), }, Config: &oauth2.Config{ MinParameterEntropy: oauth2.MinParameterEntropy, @@ -97,7 +98,7 @@ func TestOpenIDConnectRefreshHandler_PopulateTokenEndpointResponse(t *testing.T) IDTokenHandleHelper: &IDTokenHandleHelper{ IDTokenStrategy: j, }, - Config: &oauth2.Config{}, + Config: config, } for _, c := range []struct { areq *oauth2.AccessRequest @@ -146,7 +147,7 @@ func TestOpenIDConnectRefreshHandler_PopulateTokenEndpointResponse(t *testing.T) return key.PublicKey, nil }) require.NoError(t, err) - claims := decodedIdToken.Claims + claims := decodedIdToken.Claims.ToMapClaims() assert.NotEmpty(t, claims[consts.ClaimAccessTokenHash]) idTokenExp := internal.ExtractJwtExpClaim(t, idToken) require.NotEmpty(t, idTokenExp) @@ -181,7 +182,7 @@ func TestOpenIDConnectRefreshHandler_PopulateTokenEndpointResponse(t *testing.T) return key.PublicKey, nil }) require.NoError(t, err) - claims := decodedIdToken.Claims + claims := decodedIdToken.Claims.ToMapClaims() assert.NotEmpty(t, claims[consts.ClaimAccessTokenHash]) idTokenExp := internal.ExtractJwtExpClaim(t, idToken) require.NotEmpty(t, idTokenExp) diff --git a/handler/openid/helper_test.go b/handler/openid/helper_test.go index f72c3a6f..75b79e30 100644 --- a/handler/openid/helper_test.go +++ b/handler/openid/helper_test.go @@ -15,16 +15,16 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" - "authelia.com/provider/oauth2/internal/gen" "authelia.com/provider/oauth2/testing/mock" "authelia.com/provider/oauth2/token/jwt" ) var strategy = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil + Strategy: &jwt.DefaultStrategy{ + Config: &oauth2.Config{ + MinParameterEntropy: oauth2.MinParameterEntropy, }, + Issuer: jwt.MustGenDefaultIssuer(), }, Config: &oauth2.Config{ MinParameterEntropy: oauth2.MinParameterEntropy, @@ -74,7 +74,6 @@ func TestGenerateIDToken(t *testing.T) { if err == nil { assert.NotEmpty(t, token, "(%d) %s", k, c.description) } - t.Logf("Passed test case %d", k) } } diff --git a/handler/openid/strategy_jwt.go b/handler/openid/strategy_jwt.go index d573e17a..52802788 100644 --- a/handler/openid/strategy_jwt.go +++ b/handler/openid/strategy_jwt.go @@ -43,7 +43,7 @@ type DefaultSession struct { func NewDefaultSession() *DefaultSession { return &DefaultSession{ Claims: &jwt.IDTokenClaims{ - RequestedAt: time.Now().UTC(), + RequestedAt: jwt.Now(), }, Headers: &jwt.Headers{}, } @@ -110,7 +110,7 @@ func (s *DefaultSession) IDTokenClaims() *jwt.IDTokenClaims { } type DefaultStrategy struct { - jwt.Signer + jwt.Strategy Config interface { oauth2.IDTokenIssuerProvider @@ -121,7 +121,7 @@ type DefaultStrategy struct { // GenerateIDToken returns a JWT string. // -// lifespan is ignored if requester.GetSession().IDTokenClaims().ExpiresAt is not zero. +// lifespan is ignored if requester.GetSession().IDTokenClaims().ExpirationTime is not zero. // // TODO: Refactor time permitting. // @@ -141,45 +141,48 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to generate id token because subject is an empty string.")) } + jwtClient := jwt.NewIDTokenClient(requester.GetClient()) + if requester.GetRequestForm().Get(consts.FormParameterGrantType) != consts.GrantTypeRefreshToken { - maxAge, err := strconv.ParseInt(requester.GetRequestForm().Get(consts.FormParameterMaximumAge), 10, 64) - if err != nil { + var maxAge int64 + + if maxAge, err = strconv.ParseInt(requester.GetRequestForm().Get(consts.FormParameterMaximumAge), 10, 64); err != nil { maxAge = 0 } // Adds a bit of wiggle room for timing issues - if claims.AuthTime.After(time.Now().UTC().Add(time.Second * 5)) { + if claims.GetAuthTimeSafe().After(time.Now().UTC().Add(time.Second * 5)) { return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to validate OpenID Connect request because authentication time is in the future.")) } if maxAge > 0 { switch { - case claims.AuthTime.IsZero(): + case claims.AuthTime == nil, claims.AuthTime.IsZero(): return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to generate id token because authentication time claim is required when max_age is set.")) - case claims.RequestedAt.IsZero(): + case claims.RequestedAt == nil, claims.RequestedAt.IsZero(): return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to generate id token because requested at claim is required when max_age is set.")) - case claims.AuthTime.Add(time.Second * time.Duration(maxAge)).Before(claims.RequestedAt): + case claims.AuthTime.Add(time.Second * time.Duration(maxAge)).Before(claims.RequestedAt.Time): return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to generate id token because authentication time does not satisfy max_age time.")) } } prompt := requester.GetRequestForm().Get(consts.FormParameterPrompt) if prompt != "" { - if claims.AuthTime.IsZero() { + if claims.AuthTime == nil || claims.AuthTime.IsZero() { return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Unable to determine validity of prompt parameter because auth_time is missing in id token claims.")) } } switch prompt { case consts.PromptTypeNone: - if !claims.AuthTime.Equal(claims.RequestedAt) && claims.AuthTime.After(claims.RequestedAt) { + if !claims.GetAuthTimeSafe().Equal(claims.GetRequestedAtSafe()) && claims.GetAuthTimeSafe().After(claims.GetRequestedAtSafe()) { return "", errorsx.WithStack(oauth2.ErrServerError. - WithDebugf("Failed to generate id token because prompt was set to 'none' but auth_time ('%s') happened after the authorization request ('%s') was registered, indicating that the user was logged in during this request which is not allowed.", claims.AuthTime, claims.RequestedAt)) + WithDebugf("Failed to generate id token because prompt was set to 'none' but auth_time ('%s') happened after the authorization request ('%s') was registered, indicating that the user was logged in during this request which is not allowed.", claims.GetAuthTimeSafe(), claims.GetRequestedAtSafe())) } case consts.PromptTypeLogin: - if !claims.AuthTime.Equal(claims.RequestedAt) && claims.AuthTime.Before(claims.RequestedAt) { + if !claims.GetAuthTimeSafe().Equal(claims.GetRequestedAtSafe()) && claims.GetAuthTimeSafe().Before(claims.GetRequestedAtSafe()) { return "", errorsx.WithStack(oauth2.ErrServerError. - WithDebugf("Failed to generate id token because prompt was set to 'login' but auth_time ('%s') happened before the authorization request ('%s') was registered, indicating that the user was not re-authenticated which is forbidden.", claims.AuthTime, claims.RequestedAt)) + WithDebugf("Failed to generate id token because prompt was set to 'login' but auth_time ('%s') happened before the authorization request ('%s') was registered, indicating that the user was not re-authenticated which is forbidden.", claims.GetAuthTimeSafe(), claims.GetRequestedAtSafe())) } } @@ -190,7 +193,10 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura } if tokenHintString := requester.GetRequestForm().Get(consts.FormParameterIDTokenHint); tokenHintString != "" { - tokenHint, err := h.Signer.Decode(ctx, tokenHintString) + var tokenHint *jwt.Token + + tokenHint, err = h.Strategy.Decode(ctx, tokenHintString, jwt.WithClient(jwtClient)) + var ve *jwt.ValidationError if errors.As(err, &ve) && ve.Has(jwt.ValidationErrorExpired) { // Expired ID Tokens are allowed as values to id_token_hint @@ -198,24 +204,26 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura return "", errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebugf("Unable to decode id token from 'id_token_hint' parameter because %s.", err.Error())) } - if hintSub, _ := tokenHint.Claims[consts.ClaimSubject].(string); hintSub == "" { + var subHint string + + if subHint, err = tokenHint.Claims.GetSubject(); subHint == "" || err != nil { return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Provided id token from 'id_token_hint' does not have a subject.")) - } else if hintSub != claims.Subject { + } else if subHint != claims.Subject { return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Subject from authorization mismatches id token subject from 'id_token_hint'.")) } } } - if claims.ExpiresAt.IsZero() { - claims.ExpiresAt = time.Now().UTC().Add(lifespan) + if claims.ExpirationTime == nil || claims.ExpirationTime.IsZero() { + claims.ExpirationTime = jwt.NewNumericDate(time.Now().Add(lifespan)) } - if claims.ExpiresAt.Before(time.Now().UTC()) { + if claims.ExpirationTime.Before(time.Now().UTC()) { return "", errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to generate id token because expiry claim can not be in the past.")) } - if claims.AuthTime.IsZero() { - claims.AuthTime = time.Now().Truncate(time.Second).UTC() + if claims.AuthTime == nil || claims.AuthTime.IsZero() { + claims.AuthTime = jwt.Now() } if claims.Issuer == "" { @@ -232,8 +240,9 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, lifespan time.Dura } claims.Audience = stringslice.Unique(append(claims.Audience, requester.GetClient().GetID())) - claims.IssuedAt = time.Now().UTC() + claims.IssuedAt = jwt.Now() + + token, _, err = h.Strategy.Encode(ctx, claims.ToMapClaims(), jwt.WithHeaders(sess.IDTokenHeaders()), jwt.WithClient(jwtClient)) - token, _, err = h.Signer.Generate(ctx, claims.ToMapClaims(), sess.IDTokenHeaders()) return token, err } diff --git a/handler/openid/strategy_jwt_test.go b/handler/openid/strategy_jwt_test.go index fe6d2ae4..d0c479a3 100644 --- a/handler/openid/strategy_jwt_test.go +++ b/handler/openid/strategy_jwt_test.go @@ -17,14 +17,16 @@ import ( ) func TestJWTStrategy_GenerateIDToken(t *testing.T) { + config := &oauth2.Config{ + MinParameterEntropy: oauth2.MinParameterEntropy, + } + var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }}, - Config: &oauth2.Config{ - MinParameterEntropy: oauth2.MinParameterEntropy, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), }, + Config: config, } var req *oauth2.AccessRequest @@ -50,8 +52,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().UTC(), + AuthTime: jwt.Now(), + RequestedAt: jwt.Now(), }, Headers: &jwt.Headers{}, }) @@ -64,8 +66,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { setup: func() { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ - Subject: "peter", - ExpiresAt: time.Now().UTC().Add(-time.Hour), + Subject: "peter", + ExpirationTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), }, Headers: &jwt.Headers{}, }) @@ -113,8 +115,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().UTC(), + AuthTime: jwt.Now(), + RequestedAt: jwt.Now(), }, Headers: &jwt.Headers{}, }) @@ -128,7 +130,7 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), }, Headers: &jwt.Headers{}, }) @@ -142,8 +144,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.Now(), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -157,8 +159,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.Now(), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -173,8 +175,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -188,8 +190,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.Now(), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -203,8 +205,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -218,8 +220,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) @@ -239,15 +241,15 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) token, _ := j.GenerateIDToken(context.TODO(), time.Duration(0), oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ - Subject: "peter", - ExpiresAt: time.Now().Add(-time.Hour).UTC(), + Subject: "peter", + ExpirationTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), }, Headers: &jwt.Headers{}, })) @@ -261,8 +263,8 @@ func TestJWTStrategy_GenerateIDToken(t *testing.T) { req = oauth2.NewAccessRequest(&DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "peter", - AuthTime: time.Now().Add(-time.Hour).UTC(), - RequestedAt: time.Now().Add(-time.Minute), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, Headers: &jwt.Headers{}, }) diff --git a/handler/openid/validator.go b/handler/openid/validator.go index 8fc147c5..0ad6441a 100644 --- a/handler/openid/validator.go +++ b/handler/openid/validator.go @@ -31,11 +31,11 @@ type openIDConnectRequestValidatorConfigProvider interface { } type OpenIDConnectRequestValidator struct { - Strategy jwt.Signer + Strategy jwt.Strategy Config openIDConnectRequestValidatorConfigProvider } -func NewOpenIDConnectRequestValidator(strategy jwt.Signer, config openIDConnectRequestValidatorConfigProvider) *OpenIDConnectRequestValidator { +func NewOpenIDConnectRequestValidator(strategy jwt.Strategy, config openIDConnectRequestValidatorConfigProvider) *OpenIDConnectRequestValidator { return &OpenIDConnectRequestValidator{ Strategy: strategy, Config: config, @@ -108,34 +108,34 @@ func (v *OpenIDConnectRequestValidator) ValidatePrompt(ctx context.Context, req } // Adds a bit of wiggle room for timing issues - if claims.AuthTime.After(time.Now().UTC().Add(time.Second * 5)) { + if claims.GetAuthTimeSafe().After(time.Now().UTC().Add(time.Second * 5)) { return errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to validate OpenID Connect request because authentication time is in the future.")) } if maxAge > 0 { switch { - case claims.AuthTime.IsZero(): + case claims.AuthTime == nil, claims.AuthTime.IsZero(): return errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to validate OpenID Connect request because authentication time claim is required when max_age is set.")) - case claims.RequestedAt.IsZero(): + case claims.RequestedAt == nil, claims.RequestedAt.IsZero(): return errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to validate OpenID Connect request because requested at claim is required when max_age is set.")) - case claims.AuthTime.Add(time.Second * time.Duration(maxAge)).Before(claims.RequestedAt): + case claims.GetAuthTimeSafe().Add(time.Second * time.Duration(maxAge)).Before(claims.GetRequestedAtSafe()): return errorsx.WithStack(oauth2.ErrLoginRequired.WithDebug("Failed to validate OpenID Connect request because authentication time does not satisfy max_age time.")) } } if stringslice.Has(requiredPrompt, consts.PromptTypeNone) { - if claims.AuthTime.IsZero() { + if claims.AuthTime == nil || claims.AuthTime.IsZero() { return errorsx.WithStack(oauth2.ErrServerError.WithDebug("Failed to validate OpenID Connect request because because auth_time is missing from session.")) } - if !claims.AuthTime.Equal(claims.RequestedAt) && claims.AuthTime.After(claims.RequestedAt) { + if !claims.GetAuthTimeSafe().Equal(claims.GetRequestedAtSafe()) && claims.GetAuthTimeSafe().After(claims.GetRequestedAtSafe()) { // !claims.AuthTime.Truncate(time.Second).Equal(claims.RequestedAt) && claims.AuthTime.Truncate(time.Second).Before(claims.RequestedAt) { - return errorsx.WithStack(oauth2.ErrLoginRequired.WithHintf("Failed to validate OpenID Connect request because prompt was set to 'none' but auth_time ('%s') happened after the authorization request ('%s') was registered, indicating that the user was logged in during this request which is not allowed.", claims.AuthTime, claims.RequestedAt)) + return errorsx.WithStack(oauth2.ErrLoginRequired.WithHintf("Failed to validate OpenID Connect request because prompt was set to 'none' but auth_time ('%s') happened after the authorization request ('%s') was registered, indicating that the user was logged in during this request which is not allowed.", claims.GetAuthTimeSafe(), claims.GetRequestedAtSafe())) } } if stringslice.Has(requiredPrompt, consts.PromptTypeLogin) { - if claims.AuthTime.Before(claims.RequestedAt) { - return errorsx.WithStack(oauth2.ErrLoginRequired.WithHintf("Failed to validate OpenID Connect request because prompt was set to 'login' but auth_time ('%s') happened before the authorization request ('%s') was registered, indicating that the user was not re-authenticated which is forbidden.", claims.AuthTime, claims.RequestedAt)) + if claims.GetAuthTimeSafe().Before(claims.GetRequestedAtSafe()) { + return errorsx.WithStack(oauth2.ErrLoginRequired.WithHintf("Failed to validate OpenID Connect request because prompt was set to 'login' but auth_time ('%s') happened before the authorization request ('%s') was registered, indicating that the user was not re-authenticated which is forbidden.", claims.GetAuthTimeSafe(), claims.GetRequestedAtSafe())) } } @@ -144,7 +144,10 @@ func (v *OpenIDConnectRequestValidator) ValidatePrompt(ctx context.Context, req return nil } - tokenHint, err := v.Strategy.Decode(ctx, idTokenHint) + var tokenHint *jwt.Token + + tokenHint, err = v.Strategy.Decode(ctx, idTokenHint, jwt.WithIDTokenClient(req.GetClient())) + var ve *jwt.ValidationError if errors.As(err, &ve) && ve.Has(jwt.ValidationErrorExpired) { // Expired tokens are ok @@ -152,9 +155,11 @@ func (v *OpenIDConnectRequestValidator) ValidatePrompt(ctx context.Context, req return errorsx.WithStack(oauth2.ErrInvalidRequest.WithHint("Failed to validate OpenID Connect request as decoding id token from id_token_hint parameter failed.").WithWrap(err).WithDebugError(err)) } - if hintSub, _ := tokenHint.Claims[consts.ClaimSubject].(string); hintSub == "" { + var subHint string + + if subHint, err = tokenHint.Claims.GetSubject(); subHint == "" || err != nil { return errorsx.WithStack(oauth2.ErrInvalidRequest.WithHint("Failed to validate OpenID Connect request because provided id token from id_token_hint does not have a subject.")) - } else if hintSub != claims.Subject { + } else if subHint != claims.Subject { return errorsx.WithStack(oauth2.ErrLoginRequired.WithHint("Failed to validate OpenID Connect request because the subject from provided id token from id_token_hint does not match the current session's subject.")) } diff --git a/handler/openid/validator_test.go b/handler/openid/validator_test.go index 0aedb23f..a1864c89 100644 --- a/handler/openid/validator_test.go +++ b/handler/openid/validator_test.go @@ -21,11 +21,12 @@ func TestValidatePrompt(t *testing.T) { config := &oauth2.Config{ MinParameterEntropy: oauth2.MinParameterEntropy, } + var j = &DefaultStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }}, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + }, Config: &oauth2.Config{ MinParameterEntropy: oauth2.MinParameterEntropy, }, @@ -34,7 +35,7 @@ func TestValidatePrompt(t *testing.T) { v := NewOpenIDConnectRequestValidator(j, config) var genIDToken = func(c jwt.IDTokenClaims) string { - s, _, err := j.Generate(context.TODO(), c.ToMapClaims(), jwt.NewHeaders()) + s, _, err := j.Encode(context.TODO(), c.ToMapClaims()) require.NoError(t, err) return s } @@ -58,8 +59,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Minute), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, }, }, @@ -73,8 +74,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Minute), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, }, }, @@ -88,8 +89,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Minute), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, }, }, @@ -102,7 +103,7 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), + RequestedAt: jwt.Now(), }, }, }, @@ -115,8 +116,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC().Add(-time.Minute), - AuthTime: time.Now().UTC(), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + AuthTime: jwt.Now(), }, }, }, @@ -129,8 +130,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Minute), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, }, }, @@ -143,8 +144,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.Now(), }, }, }, @@ -157,8 +158,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.Now(), }, }, }, @@ -171,8 +172,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC().Add(-time.Second * 5), - AuthTime: time.Now().UTC().Add(-time.Second), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Second * 5)), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), }, }, }, @@ -185,8 +186,8 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC().Add(-time.Second * 5), - AuthTime: time.Now().UTC().Add(-time.Second), + RequestedAt: jwt.NewNumericDate(time.Now().Add(-time.Second * 5)), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), }, }, }, @@ -199,14 +200,14 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Second), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), }, }, idTokenHint: genIDToken(jwt.IDTokenClaims{ - Subject: "bar", - RequestedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour), + Subject: "bar", + RequestedAt: jwt.Now(), + ExpirationTime: jwt.NewNumericDate(time.Now().Add(time.Hour)), }), }, { @@ -218,14 +219,14 @@ func TestValidatePrompt(t *testing.T) { Subject: "foo", Claims: &jwt.IDTokenClaims{ Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Second), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), }, }, idTokenHint: genIDToken(jwt.IDTokenClaims{ - Subject: "foo", - RequestedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour), + Subject: "foo", + RequestedAt: jwt.Now(), + ExpirationTime: jwt.NewNumericDate(time.Now().Add(time.Hour)), }), }, { @@ -236,21 +237,20 @@ func TestValidatePrompt(t *testing.T) { s: &DefaultSession{ Subject: "foo", Claims: &jwt.IDTokenClaims{ - Subject: "foo", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().UTC().Add(-time.Second), - ExpiresAt: time.Now().UTC().Add(-time.Second), + Subject: "foo", + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), + ExpirationTime: jwt.NewNumericDate(time.Now().Add(-time.Second)), }, }, idTokenHint: genIDToken(jwt.IDTokenClaims{ - Subject: "foo", - RequestedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour), + Subject: "foo", + RequestedAt: jwt.Now(), + ExpirationTime: jwt.NewNumericDate(time.Now().Add(time.Hour)), }), }, } { t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { - t.Logf("%s", tc.idTokenHint) err := v.ValidatePrompt(context.TODO(), &oauth2.AuthorizeRequest{ Request: oauth2.Request{ Form: url.Values{"prompt": {tc.prompt}, "id_token_hint": {tc.idTokenHint}}, diff --git a/handler/rfc7523/handler.go b/handler/rfc7523/handler.go index 166320d4..8054bb8e 100644 --- a/handler/rfc7523/handler.go +++ b/handler/rfc7523/handler.go @@ -45,8 +45,8 @@ var ( // TODO: Refactor time permitting. // //nolint:gocyclo -func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request oauth2.AccessRequester) error { - if err := c.CheckRequest(ctx, request); err != nil { +func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request oauth2.AccessRequester) (err error) { + if err = c.CheckRequest(ctx, request); err != nil { return err } diff --git a/handler/rfc8628/device_authorize_handler.go b/handler/rfc8628/device_authorize_handler.go index 8c879657..dad35955 100644 --- a/handler/rfc8628/device_authorize_handler.go +++ b/handler/rfc8628/device_authorize_handler.go @@ -7,6 +7,7 @@ import ( "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -38,7 +39,7 @@ func (d *DeviceAuthorizeHandler) HandleRFC8628DeviceAuthorizeEndpointRequest(ctx dar.SetDeviceCodeSignature(deviceCodeSignature) dar.SetUserCodeSignature(userCodeSignature) - expireAt := time.Now().UTC().Add(d.Config.GetRFC8628CodeLifespan(ctx)).Round(time.Second) + expireAt := time.Now().UTC().Add(d.Config.GetRFC8628CodeLifespan(ctx)).Truncate(jwt.TimePrecision) session.SetExpiresAt(oauth2.DeviceCode, expireAt) session.SetExpiresAt(oauth2.UserCode, expireAt) diff --git a/handler/rfc8628/token_endpoint_handler_test.go b/handler/rfc8628/token_endpoint_handler_test.go index 514a1c38..ecdc0e84 100644 --- a/handler/rfc8628/token_endpoint_handler_test.go +++ b/handler/rfc8628/token_endpoint_handler_test.go @@ -19,6 +19,7 @@ import ( "authelia.com/provider/oauth2/storage" "authelia.com/provider/oauth2/testing/mock" "authelia.com/provider/oauth2/token/hmac" + "authelia.com/provider/oauth2/token/jwt" ) var o2hmacshaStrategy = hoauth2.HMACCoreStrategy{ @@ -428,16 +429,16 @@ func TestDeviceAuthorizeCode_HandleTokenEndpointRequest(t *testing.T) { }, }, check: func(t *testing.T, areq *oauth2.AccessRequest, authreq *oauth2.DeviceAuthorizeRequest) { - assert.Equal(t, time.Now().Add(time.Minute).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) - assert.Equal(t, time.Now().Add(time.Minute).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) + assert.Equal(t, time.Now().Add(time.Minute).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.AccessToken)) + assert.Equal(t, time.Now().Add(time.Minute).UTC().Truncate(jwt.TimePrecision), areq.GetSession().GetExpiresAt(oauth2.RefreshToken)) }, setup: func(t *testing.T, areq *oauth2.AccessRequest, authreq *oauth2.DeviceAuthorizeRequest) { authreq = oauth2.NewDeviceAuthorizeRequest() authreq.SetSession(openid.NewDefaultSession()) authreq.GetSession().SetExpiresAt(oauth2.UserCode, - time.Now().Add(-time.Hour).UTC().Round(time.Second)) + time.Now().Add(-time.Hour).UTC().Truncate(jwt.TimePrecision)) authreq.GetSession().SetExpiresAt(oauth2.DeviceCode, - time.Now().Add(-time.Hour).UTC().Round(time.Second)) + time.Now().Add(-time.Hour).UTC().Truncate(jwt.TimePrecision)) dCode, dSig, err := strategy.GenerateRFC8628DeviceCode(context.TODO()) require.NoError(t, err) _, uSig, err := strategy.GenerateRFC8628UserCode(context.TODO()) @@ -458,8 +459,6 @@ func TestDeviceAuthorizeCode_HandleTokenEndpointRequest(t *testing.T) { c.setup(t, c.areq, c.authreq) } - t.Logf("Processing %+v", c.areq.Client) - err := h.HandleTokenEndpointRequest(context.Background(), c.areq) if c.expectErr != nil { require.EqualError(t, err, c.expectErr.Error(), "%+v", err) diff --git a/handler/rfc8628/user_authorize_handler_test.go b/handler/rfc8628/user_authorize_handler_test.go index c02703d8..7240a072 100644 --- a/handler/rfc8628/user_authorize_handler_test.go +++ b/handler/rfc8628/user_authorize_handler_test.go @@ -16,6 +16,7 @@ import ( "authelia.com/provider/oauth2/handler/openid" . "authelia.com/provider/oauth2/handler/rfc8628" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" ) func TestUserAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpointResponse(t *testing.T) { @@ -37,7 +38,7 @@ func TestUserAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpointResponse(t *te dar.SetSession(openid.NewDefaultSession()) dar.GetSession().SetExpiresAt(oauth2.UserCode, time.Now().UTC().Add( - f.Config.GetRFC8628CodeLifespan(a.ctx)).Round(time.Second)) + f.Config.GetRFC8628CodeLifespan(a.ctx)).Truncate(jwt.TimePrecision)) code, sig, err := f.Strategy.GenerateRFC8628UserCode(a.ctx) require.NoError(t, err) dar.SetUserCodeSignature(sig) @@ -243,7 +244,7 @@ func TestUserAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpointResponse_Handl dar.SetSession(openid.NewDefaultSession()) dar.GetSession().SetExpiresAt(oauth2.UserCode, time.Now().UTC().Add( - f.Config.GetRFC8628CodeLifespan(a.ctx)).Round(time.Second)) + f.Config.GetRFC8628CodeLifespan(a.ctx)).Truncate(jwt.TimePrecision)) code, sig, err := f.Strategy.GenerateRFC8628UserCode(a.ctx) require.NoError(t, err) dar.SetUserCodeSignature(sig) @@ -400,7 +401,7 @@ func TestUserAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpointResponse_Handl dar.SetSession(openid.NewDefaultSession()) dar.GetSession().SetExpiresAt(oauth2.UserCode, time.Now().UTC().Add( - f.Config.GetRFC8628CodeLifespan(a.ctx)).Round(time.Second)) + f.Config.GetRFC8628CodeLifespan(a.ctx)).Truncate(jwt.TimePrecision)) code, sig, err := f.Strategy.GenerateRFC8628UserCode(a.ctx) require.NoError(t, err) dar.SetUserCodeSignature(sig) @@ -443,7 +444,7 @@ func TestUserAuthorizeHandler_PopulateRFC8628UserAuthorizeEndpointResponse_Handl dar.SetSession(openid.NewDefaultSession()) dar.GetSession().SetExpiresAt(oauth2.UserCode, time.Now().UTC().Add(time.Duration(-1)* - f.Config.GetRFC8628CodeLifespan(a.ctx)).Round(time.Second)) + f.Config.GetRFC8628CodeLifespan(a.ctx)).Truncate(jwt.TimePrecision)) code, sig, err := f.Strategy.GenerateRFC8628UserCode(a.ctx) require.NoError(t, err) dar.SetUserCodeSignature(sig) diff --git a/handler/rfc8693/access_token_type_handler.go b/handler/rfc8693/access_token_type_handler.go index 7e83ae76..b03551c9 100644 --- a/handler/rfc8693/access_token_type_handler.go +++ b/handler/rfc8693/access_token_type_handler.go @@ -10,6 +10,7 @@ import ( hoauth2 "authelia.com/provider/oauth2/handler/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -161,7 +162,7 @@ func (c *AccessTokenTypeHandler) issue(ctx context.Context, request oauth2.Acces issueRefreshToken := c.canIssueRefreshToken(request) if issueRefreshToken { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.RefreshTokenLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.RefreshTokenLifespan).Truncate(jwt.TimePrecision)) refresh, refreshSignature, err := c.CoreStrategy.GenerateRefreshToken(ctx, request) if err != nil { return errors.WithStack(oauth2.ErrServerError.WithDebugError(err)) diff --git a/handler/rfc8693/custom_jwt_type_handler.go b/handler/rfc8693/custom_jwt_type_handler.go index 768900f6..2fb274c2 100644 --- a/handler/rfc8693/custom_jwt_type_handler.go +++ b/handler/rfc8693/custom_jwt_type_handler.go @@ -15,8 +15,9 @@ import ( ) type CustomJWTTypeHandler struct { - Config oauth2.RFC8693ConfigProvider - JWTStrategy jwt.Signer + Config oauth2.RFC8693ConfigProvider + + jwt.Strategy Storage } @@ -122,7 +123,7 @@ func (c *CustomJWTTypeHandler) validate(ctx context.Context, _ oauth2.AccessRequ if window == 0 { window = 1 * time.Hour } - claims := ftoken.Claims + claims := ftoken.Claims.ToMapClaims() if issued, exists := claims[consts.ClaimIssuedAt]; exists { if time.Unix(toInt64(issued), 0).Add(window).Before(time.Now()) { @@ -154,7 +155,7 @@ func (c *CustomJWTTypeHandler) validate(ctx context.Context, _ oauth2.AccessRequ } } - return map[string]any(claims), nil + return claims, nil } func (c *CustomJWTTypeHandler) issue(ctx context.Context, request oauth2.AccessRequester, tokenType oauth2.RFC8693TokenType, response oauth2.AccessResponder) error { @@ -175,8 +176,8 @@ func (c *CustomJWTTypeHandler) issue(ctx context.Context, request oauth2.AccessR claims.Subject = request.GetClient().GetID() } - if claims.ExpiresAt.IsZero() { - claims.ExpiresAt = time.Now().UTC().Add(jwtType.Expiry) + if claims.ExpirationTime == nil || claims.ExpirationTime.IsZero() { + claims.ExpirationTime = jwt.NewNumericDate(time.Now().Add(jwtType.Expiry)) } if claims.Issuer == "" { @@ -200,16 +201,17 @@ func (c *CustomJWTTypeHandler) issue(ctx context.Context, request oauth2.AccessR claims.JTI = uuid.New().String() } - claims.IssuedAt = time.Now().UTC() + claims.IssuedAt = jwt.Now() - token, _, err := c.JWTStrategy.Generate(ctx, claims.ToMapClaims(), sess.IDTokenHeaders()) + token, _, err := c.Strategy.Encode(ctx, claims.ToMapClaims(), jwt.WithHeaders(sess.IDTokenHeaders()), jwt.WithIDTokenClient(request.GetClient())) if err != nil { return err } response.SetAccessToken(token) response.SetTokenType("N_A") - response.SetExpiresIn(time.Duration(claims.ExpiresAt.UnixNano() - time.Now().UTC().UnixNano())) + response.SetExpiresIn(time.Duration(claims.GetExpirationTimeSafe().UnixNano() - time.Now().UTC().UnixNano())) + return nil } diff --git a/handler/rfc8693/id_token_type_handler.go b/handler/rfc8693/id_token_type_handler.go index 6a768563..b000fbb6 100644 --- a/handler/rfc8693/id_token_type_handler.go +++ b/handler/rfc8693/id_token_type_handler.go @@ -12,7 +12,7 @@ import ( type IDTokenTypeHandler struct { Config oauth2.Configurator - JWTStrategy jwt.Signer + Strategy jwt.Strategy IssueStrategy openid.OpenIDConnectTokenStrategy ValidationStrategy openid.TokenValidationStrategy Storage @@ -118,7 +118,7 @@ func (c *IDTokenTypeHandler) validate(ctx context.Context, request oauth2.Access return nil, errorsx.WithStack(oauth2.ErrInvalidRequest.WithHint("Claim 'sub' is missing.")) } - return map[string]any(claims), nil + return claims, nil } func (c *IDTokenTypeHandler) issue(ctx context.Context, request oauth2.AccessRequester, response oauth2.AccessResponder) error { diff --git a/handler/rfc8693/refresh_token_type_handler.go b/handler/rfc8693/refresh_token_type_handler.go index c7e673b2..3e31f251 100644 --- a/handler/rfc8693/refresh_token_type_handler.go +++ b/handler/rfc8693/refresh_token_type_handler.go @@ -10,6 +10,7 @@ import ( hoauth2 "authelia.com/provider/oauth2/handler/oauth2" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/storage" + "authelia.com/provider/oauth2/token/jwt" "authelia.com/provider/oauth2/x/errorsx" ) @@ -148,7 +149,7 @@ func (c *RefreshTokenTypeHandler) validate(ctx context.Context, request oauth2.A } func (c *RefreshTokenTypeHandler) issue(ctx context.Context, request oauth2.AccessRequester, response oauth2.AccessResponder) error { - request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.RefreshTokenLifespan).Round(time.Second)) + request.GetSession().SetExpiresAt(oauth2.RefreshToken, time.Now().UTC().Add(c.RefreshTokenLifespan).Truncate(jwt.TimePrecision)) refresh, refreshSignature, err := c.CoreStrategy.GenerateRefreshToken(ctx, request) if err != nil { return errors.WithStack(oauth2.ErrServerError.WithDebugError(err)) diff --git a/handler/rfc8693/token_exchange_test.go b/handler/rfc8693/token_exchange_test.go index e1a595d0..3dbf3d19 100644 --- a/handler/rfc8693/token_exchange_test.go +++ b/handler/rfc8693/token_exchange_test.go @@ -30,12 +30,6 @@ func TestAccessTokenExchangeImpersonation(t *testing.T) { store := storage.NewExampleStore() jwtName := "urn:custom:jwt" - jwtSigner := &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, - } - customJWTType := &JWTType{ Name: jwtName, JWTValidationConfig: JWTValidationConfig{ @@ -70,6 +64,11 @@ func TestAccessTokenExchangeImpersonation(t *testing.T) { DefaultRequestedTokenType: consts.TokenTypeRFC8693AccessToken, } + strategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), + } + coreStrategy := &hoauth2.HMACCoreStrategy{ Enigma: &hmac.HMACStrategy{Config: config}, Config: config, @@ -93,10 +92,9 @@ func TestAccessTokenExchangeImpersonation(t *testing.T) { customJWTHandler := &CustomJWTTypeHandler{ Config: config, - JWTStrategy: &jwt.DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(key), }, Storage: store, } @@ -142,7 +140,7 @@ func TestAccessTokenExchangeImpersonation(t *testing.T) { Client: store.Clients["my-client"], Form: url.Values{ "subject_token_type": []string{jwtName}, - "subject_token": []string{createJWT(context.Background(), jwtSigner, jwt.MapClaims{ + "subject_token": []string{createJWT(context.Background(), store.Clients["my-client"], strategy, jwt.MapClaims{ "subject": "peter_for_jwt", "jti": uuid.New(), "iss": "https://customory.com", @@ -267,8 +265,9 @@ func createAccessToken(ctx context.Context, coreStrategy hoauth2.CoreStrategy, s return token } -func createJWT(ctx context.Context, signer jwt.Signer, claims jwt.MapClaims) string { - token, _, err := signer.Generate(ctx, claims, &jwt.Headers{}) +func createJWT(ctx context.Context, client any, strategy jwt.Strategy, claims jwt.MapClaims) string { + token, _, err := strategy.Encode(ctx, claims, jwt.WithIDTokenClient(client)) + if err != nil { panic(err.Error()) } diff --git a/helper_test.go b/helper_test.go index 0c68a9fa..fe5839ae 100644 --- a/helper_test.go +++ b/helper_test.go @@ -25,7 +25,6 @@ func TestStringInSlice(t *testing.T) { {needle: "foo", haystack: []string{}, ok: false}, } { assert.Equal(t, c.ok, StringInSlice(c.needle, c.haystack), "%d", k) - t.Logf("Passed test case %d", k) } } diff --git a/integration/authorize_code_grant_public_client_pkce_test.go b/integration/authorize_code_grant_public_client_pkce_test.go index 0d72a973..8e1a3525 100644 --- a/integration/authorize_code_grant_public_client_pkce_test.go +++ b/integration/authorize_code_grant_public_client_pkce_test.go @@ -79,19 +79,11 @@ func runAuthorizeCodeGrantWithPublicClientAndPKCETest(t *testing.T, strategy any t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), func(t *testing.T) { c.setup() - t.Logf("Got url: %s", authCodeUrl) - resp, err := http.Get(authCodeUrl) //nolint:gosec require.NoError(t, err) require.Equal(t, resp.StatusCode, c.authStatusCode) if resp.StatusCode == http.StatusOK { - // This should fail because no verifier was given - // _, err := oauthClient.Exchange(xoauth2.NoContext, resp.Request.URL.Query().Get(consts.FormParameterAuthorizationCode)) - // require.Error(t, err) - // require.Empty(t, token.AccessToken) - t.Logf("Got redirect url: %s", resp.Request.URL) - resp, err := http.PostForm(ts.URL+"/token", url.Values{ consts.FormParameterAuthorizationCode: {resp.Request.URL.Query().Get(consts.FormParameterAuthorizationCode)}, consts.FormParameterGrantType: {consts.GrantTypeAuthorizationCode}, diff --git a/integration/client_credentials_grant_test.go b/integration/client_credentials_grant_test.go index 261eb8fc..d38f8449 100644 --- a/integration/client_credentials_grant_test.go +++ b/integration/client_credentials_grant_test.go @@ -145,8 +145,6 @@ func runClientCredentialsGrantTest(t *testing.T, strategy hoauth2.AccessTokenStr if c.check != nil { c.check(t, token) } - - t.Logf("Passed test case %d", k) }) } } diff --git a/integration/helper_setup_test.go b/integration/helper_setup_test.go index e5ff0038..436f650b 100644 --- a/integration/helper_setup_test.go +++ b/integration/helper_setup_test.go @@ -4,7 +4,6 @@ package integration_test import ( - "context" "crypto" "crypto/rand" "crypto/rsa" @@ -186,10 +185,9 @@ var hmacStrategy = &hoauth2.HMACCoreStrategy{ var defaultRSAKey = gen.MustRSAKey() var jwtStrategy = &hoauth2.JWTProfileCoreStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return defaultRSAKey, nil - }, + Strategy: &jwt.DefaultStrategy{ + Config: &oauth2.Config{}, + Issuer: jwt.NewDefaultIssuerRS256Unverified(defaultRSAKey), }, Config: &oauth2.Config{}, HMACCoreStrategy: hmacStrategy, diff --git a/integration/introspect_token_test.go b/integration/introspect_token_test.go index 72adec7d..3c6de835 100644 --- a/integration/introspect_token_test.go +++ b/integration/introspect_token_test.go @@ -26,10 +26,9 @@ func TestIntrospectToken(t *testing.T) { EnforceJWTProfileAccessTokens: true, } - signer := &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (any, error) { - return defaultRSAKey, nil - }, + strategy := &jwt.DefaultStrategy{ + Config: config, + Issuer: jwt.NewDefaultIssuerRS256Unverified(defaultRSAKey), } for _, c := range []struct { @@ -44,16 +43,15 @@ func TestIntrospectToken(t *testing.T) { }, { description: "JWT strategy with OAuth2TokenIntrospectionFactory", - strategy: hoauth2.NewCoreStrategy(config, "authelia_%s_", signer), + strategy: hoauth2.NewCoreStrategy(config, "authelia_%s_", strategy), factory: compose.OAuth2TokenIntrospectionFactory, }, { description: "JWT strategy with OAuth2StatelessJWTIntrospectionFactory", - strategy: hoauth2.NewCoreStrategy(config, "authelia_%s_", signer), + strategy: hoauth2.NewCoreStrategy(config, "authelia_%s_", strategy), factory: compose.OAuth2StatelessJWTIntrospectionFactory, }, } { - t.Logf("testing %v", c.description) runIntrospectTokenTest(t, c.strategy, c.factory) } } @@ -125,7 +123,6 @@ func runIntrospectTokenTest(t *testing.T, strategy hoauth2.AccessTokenStrategy, _, bytes, errs := c.prepare(s).End() assert.Nil(t, json.Unmarshal([]byte(bytes), &res)) - t.Logf("Got answer: %s", bytes) assert.Len(t, errs, 0) assert.Equal(t, c.isActive, res.Active) diff --git a/integration/oidc_explicit_test.go b/integration/oidc_explicit_test.go index 77c2ec65..e71a535c 100644 --- a/integration/oidc_explicit_test.go +++ b/integration/oidc_explicit_test.go @@ -105,8 +105,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(time.Second).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(time.Second)), }), description: "should not pass missing redirect uri", setup: func(oauthClient *xoauth2.Config) string { @@ -130,8 +130,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(time.Second).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(time.Second)), }), description: "should pass", setup: func(oauthClient *xoauth2.Config) string { @@ -143,8 +143,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(time.Second).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(time.Second)), }), description: "should not pass missing redirect uri", setup: func(oauthClient *xoauth2.Config) string { @@ -158,8 +158,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(time.Second).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(time.Second)), }), description: "should not pass missing redirect uri", setup: func(oauthClient *xoauth2.Config) string { @@ -173,8 +173,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(-time.Minute).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }), description: "should fail because authentication was in the past", setup: func(oauthClient *xoauth2.Config) string { @@ -187,8 +187,8 @@ func TestOpenIDConnectExplicitFlow(t *testing.T) { { session: newIDSession(&jwt.IDTokenClaims{ Subject: "peter", - RequestedAt: time.Now().UTC(), - AuthTime: time.Now().Add(-time.Minute).UTC(), + RequestedAt: jwt.Now(), + AuthTime: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }), description: "should pass because authorization was in the past and no login was required", setup: func(oauthClient *xoauth2.Config) string { diff --git a/integration/resource_owner_password_credentials_grant_test.go b/integration/resource_owner_password_credentials_grant_test.go index ec235b02..ac024223 100644 --- a/integration/resource_owner_password_credentials_grant_test.go +++ b/integration/resource_owner_password_credentials_grant_test.go @@ -94,7 +94,5 @@ func runResourceOwnerPasswordCredentialsGrantTest(t *testing.T, strategy hoauth2 c.check(t, token) } } - - t.Logf("Passed test case %d", k) } } diff --git a/internal/consts/client_auth_method.go b/internal/consts/client_auth_method.go index 982284d0..b2bec6c9 100644 --- a/internal/consts/client_auth_method.go +++ b/internal/consts/client_auth_method.go @@ -1,6 +1,6 @@ package consts -// Client Auth Method strings. +// Client Auth SignatureAlgorithm strings. const ( ClientAuthMethodClientSecretBasic = "client_secret_basic" ClientAuthMethodClientSecretPost = "client_secret_post" diff --git a/internal/consts/const.go b/internal/consts/const.go index 22911e86..5fbd19de 100644 --- a/internal/consts/const.go +++ b/internal/consts/const.go @@ -16,4 +16,5 @@ const ( valueNonce = "nonce" valueDeviceCode = "device_code" valueUserCode = "user_code" + valueEnc = "enc" ) diff --git a/internal/consts/jwt.go b/internal/consts/jwt.go index 37ddff52..ecc35b4f 100644 --- a/internal/consts/jwt.go +++ b/internal/consts/jwt.go @@ -1,21 +1,26 @@ package consts const ( - JSONWebTokenHeaderKeyIdentifier = "kid" - JSONWebTokenHeaderAlgorithm = "alg" - JSONWebTokenHeaderUse = "use" - JSONWebTokenHeaderType = "typ" + JSONWebTokenHeaderKeyIdentifier = "kid" + JSONWebTokenHeaderAlgorithm = "alg" + JSONWebTokenHeaderEncryptionAlgorithm = valueEnc + JSONWebTokenHeaderCompressionAlgorithm = "zip" + JSONWebTokenHeaderPBES2Count = "p2c" + + JSONWebTokenHeaderType = "typ" + JSONWebTokenHeaderContentType = "cty" ) const ( JSONWebTokenUseSignature = "sig" - JSONWebTokenUseEncryption = "enc" + JSONWebTokenUseEncryption = valueEnc ) const ( - JSONWebTokenTypeJWT = "JWT" - JSONWebTokenTypeAccessToken = "at+jwt" - JSONWebTokenTypeTokenIntrospection = "token-introspection+jwt" + JSONWebTokenTypeJWT = "JWT" + JSONWebTokenTypeAccessToken = "at+jwt" + JSONWebTokenTypeAccessTokenAlternative = "application/at+jwt" + JSONWebTokenTypeTokenIntrospection = "token-introspection+jwt" ) const ( diff --git a/internal/consts/spec.go b/internal/consts/spec.go index db8742b8..bdfa6647 100644 --- a/internal/consts/spec.go +++ b/internal/consts/spec.go @@ -7,7 +7,7 @@ const ( PromptTypeSelectAccount = "select_account" ) -// Proof Key Code Exchange Challenge Method strings. +// Proof Key Code Exchange Challenge SignatureAlgorithm strings. const ( PKCEChallengeMethodPlain = "plain" PKCEChallengeMethodSHA256 = "S256" diff --git a/internal/mock/client_secret.go b/internal/mock/client_secret.go new file mode 100644 index 00000000..6d430bf3 --- /dev/null +++ b/internal/mock/client_secret.go @@ -0,0 +1,83 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: authelia.com/provider/oauth2 (interfaces: ClientSecret) +// +// Generated by this command: +// +// mockgen -package internal -destination internal/mock_client_secret.go authelia.com/provider/oauth2 ClientSecret +// + +// Package internal is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockClientSecret is a mock of ClientSecret interface. +type MockClientSecret struct { + ctrl *gomock.Controller + recorder *MockClientSecretMockRecorder +} + +// MockClientSecretMockRecorder is the mock recorder for MockClientSecret. +type MockClientSecretMockRecorder struct { + mock *MockClientSecret +} + +// NewMockClientSecret creates a new mock instance. +func NewMockClientSecret(ctrl *gomock.Controller) *MockClientSecret { + mock := &MockClientSecret{ctrl: ctrl} + mock.recorder = &MockClientSecretMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClientSecret) EXPECT() *MockClientSecretMockRecorder { + return m.recorder +} + +// Compare mocks base method. +func (m *MockClientSecret) Compare(arg0 context.Context, arg1 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Compare", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Compare indicates an expected call of Compare. +func (mr *MockClientSecretMockRecorder) Compare(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Compare", reflect.TypeOf((*MockClientSecret)(nil).Compare), arg0, arg1) +} + +// GetPlainTextValue mocks base method. +func (m *MockClientSecret) GetPlainTextValue() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPlainTextValue") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPlainTextValue indicates an expected call of GetPlainTextValue. +func (mr *MockClientSecretMockRecorder) GetPlainTextValue() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPlainTextValue", reflect.TypeOf((*MockClientSecret)(nil).GetPlainTextValue)) +} + +// IsPlainText mocks base method. +func (m *MockClientSecret) IsPlainText() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsPlainText") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsPlainText indicates an expected call of IsPlainText. +func (mr *MockClientSecretMockRecorder) IsPlainText() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPlainText", reflect.TypeOf((*MockClientSecret)(nil).IsPlainText)) +} diff --git a/internal/randx/sequence.go b/internal/randx/sequence.go index 0b0163ac..f19cc1c3 100644 --- a/internal/randx/sequence.go +++ b/internal/randx/sequence.go @@ -13,20 +13,28 @@ var rander = rand.Reader // random function var ( // AlphaNum contains runes [abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789]. AlphaNum = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + // Alpha contains runes [abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ]. Alpha = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + // AlphaLowerNum contains runes [abcdefghijklmnopqrstuvwxyz0123456789]. AlphaLowerNum = []rune("abcdefghijklmnopqrstuvwxyz0123456789") + // AlphaUpperNum contains runes [ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789]. AlphaUpperNum = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + // AlphaLower contains runes [abcdefghijklmnopqrstuvwxyz]. AlphaLower = []rune("abcdefghijklmnopqrstuvwxyz") + // AlphaUpperVowels contains runes [AEIOUY]. AlphaUpperVowels = []rune("AEIOUY") + // AlphaUpperNoVowels contains runes [BCDFGHJKLMNPQRSTVWXZ]. AlphaUpperNoVowels = []rune("BCDFGHJKLMNPQRSTVWXZ") + // AlphaUpper contains runes [ABCDEFGHIJKLMNOPQRSTUVWXYZ]. AlphaUpper = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + // Numeric contains runes [0123456789]. Numeric = []rune("0123456789") ) @@ -36,11 +44,13 @@ func RuneSequence(l int, allowedRunes []rune) (seq []rune, err error) { c := big.NewInt(int64(len(allowedRunes))) seq = make([]rune, l) + var r *big.Int + for i := 0; i < l; i++ { - r, err := rand.Int(rander, c) - if err != nil { + if r, err = rand.Int(rander, c); err != nil { return seq, err } + rn := allowedRunes[r.Uint64()] seq[i] = rn } @@ -50,9 +60,9 @@ func RuneSequence(l int, allowedRunes []rune) (seq []rune, err error) { // MustString returns a random string sequence using the defined runes. Panics on error. func MustString(l int, allowedRunes []rune) string { - seq, err := RuneSequence(l, allowedRunes) - if err != nil { + if seq, err := RuneSequence(l, allowedRunes); err != nil { panic(err) + } else { + return string(seq) } - return string(seq) } diff --git a/internal/test_helpers.go b/internal/test_helpers.go index 3638a1d6..d09e32c2 100644 --- a/internal/test_helpers.go +++ b/internal/test_helpers.go @@ -12,13 +12,13 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" "golang.org/x/net/html" xoauth2 "golang.org/x/oauth2" "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" ) func ptr(d time.Duration) *time.Duration { @@ -61,18 +61,17 @@ func RequireEqualTime(t *testing.T, expected time.Time, actual time.Time, precis } func ExtractJwtExpClaim(t *testing.T, token string) *time.Time { - parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + claims := &jwt.IDTokenClaims{} - claims := &jwt.RegisteredClaims{} + _, err := jwt.UnsafeParseSignedAny(token, claims) - _, _, err := parser.ParseUnverified(token, claims) require.NoError(t, err) - if claims.ExpiresAt == nil { + if claims.ExpirationTime == nil { return nil } - return &claims.ExpiresAt.Time + return &claims.ExpirationTime.Time } //nolint:gocyclo diff --git a/introspection_request_handler_test.go b/introspection_request_handler_test.go index 3dcf0b5b..05c16f95 100644 --- a/introspection_request_handler_test.go +++ b/introspection_request_handler_test.go @@ -342,10 +342,6 @@ func TestIntrospectionResponseToMap(t *testing.T) { consts.ClaimAudience: []string{"https://example.com", "aclient"}, consts.ClaimIssuedAt: int64(100000), consts.ClaimClientIdentifier: "aclient", - //"aclaim": 1, - //consts.ClaimSubject: "asubj", - //consts.ClaimExpirationTime: int64(1000000), - //consts.ClaimUsername: "auser", }, }, } diff --git a/introspection_response_writer.go b/introspection_response_writer.go index c1b86104..99c7f2cb 100644 --- a/introspection_response_writer.go +++ b/introspection_response_writer.go @@ -272,7 +272,7 @@ func (f *Fosite) writeIntrospectionResponse(ctx context.Context, rw http.Respons return } - claims := map[string]any{ + claims := jwt.MapClaims{ consts.ClaimJWTID: jti.String(), consts.ClaimIssuer: f.Config.GetIntrospectionIssuer(ctx), consts.ClaimIssuedAt: time.Now().UTC().Unix(), @@ -283,15 +283,15 @@ func (f *Fosite) writeIntrospectionResponse(ctx context.Context, rw http.Respons claims[consts.ClaimAudience] = aud } - signer := f.Config.GetIntrospectionJWTResponseSigner(ctx) + strategy := f.Config.GetIntrospectionJWTResponseStrategy(ctx) - if signer == nil { - f.WriteIntrospectionError(ctx, rw, errors.WithStack(ErrServerError.WithHint("Failed to generate the response.").WithDebug("The Introspection JWT could not be generated as the server is misconfigured. The Introspection Signer was not configured."))) + if strategy == nil { + f.WriteIntrospectionError(ctx, rw, errors.WithStack(ErrServerError.WithHint("Failed to generate the response.").WithDebug("The Introspection JWT could not be generated as the server is misconfigured. The Introspection jwt.Strategy was not configured."))) return } - if token, _, err = signer.Generate(ctx, claims, header); err != nil { + if token, _, err = strategy.Encode(ctx, claims, jwt.WithHeaders(header), jwt.WithIntrospectionClient(r.GetAccessRequester().GetClient())); err != nil { f.WriteIntrospectionError(ctx, rw, errors.WithStack(ErrServerError.WithHint("Failed to generate the response.").WithDebugf("The Introspection JWT itself could not be generated with error %+v.", err))) return diff --git a/introspection_response_writer_test.go b/introspection_response_writer_test.go index faf033a7..be35c122 100644 --- a/introspection_response_writer_test.go +++ b/introspection_response_writer_test.go @@ -99,7 +99,7 @@ func TestWriteIntrospectionResponseBody(t *testing.T) { hasExtra: false, }, { - description: "should success for ExpiresAt not set access token", + description: "should success for ExpirationTime not set access token", setup: func() { ires.Active = true ires.TokenUse = AccessToken diff --git a/pushed_authorize_response_writer_test.go b/pushed_authorize_response_writer_test.go index cff8428f..b92e781f 100644 --- a/pushed_authorize_response_writer_test.go +++ b/pushed_authorize_response_writer_test.go @@ -57,6 +57,5 @@ func TestNewPushedAuthorizeResponse(t *testing.T) { } else { assert.NotNil(t, responder, "%d", k) } - t.Logf("Passed test case %d", k) } } diff --git a/response_handler.go b/response_handler.go index 6823956c..051f3495 100644 --- a/response_handler.go +++ b/response_handler.go @@ -256,7 +256,7 @@ type ResponseModeHandlerConfigurator interface { FormPostHTMLTemplateProvider FormPostResponseProvider JWTSecuredAuthorizeResponseModeIssuerProvider - JWTSecuredAuthorizeResponseModeSignerProvider + JWTSecuredAuthorizeResponseModeStrategyProvider JWTSecuredAuthorizeResponseModeLifespanProvider MessageCatalogProvider SendDebugMessagesToClientsProvider diff --git a/rfc8628_device_authorize_response_writer_test.go b/rfc8628_device_authorize_response_writer_test.go index 798e6fea..884ce975 100644 --- a/rfc8628_device_authorize_response_writer_test.go +++ b/rfc8628_device_authorize_response_writer_test.go @@ -68,6 +68,5 @@ func TestNewDeviceResponse(t *testing.T) { } else { assert.NotNil(t, responder, "%d", k) } - t.Logf("Passed test case %d", k) } } diff --git a/rfc8628_device_authorize_write_test.go b/rfc8628_device_authorize_write_test.go index 4b7eb5b8..8e307ce6 100644 --- a/rfc8628_device_authorize_write_test.go +++ b/rfc8628_device_authorize_write_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" . "authelia.com/provider/oauth2" + "authelia.com/provider/oauth2/token/jwt" ) func TestWriteDeviceAuthorizeResponse(t *testing.T) { @@ -25,10 +26,10 @@ func TestWriteDeviceAuthorizeResponse(t *testing.T) { resp.SetUserCode("AAAA") resp.SetDeviceCode("BBBB") resp.SetInterval(int( - oauth2.Config.GetRFC8628TokenPollingInterval(context.TODO()).Round(time.Second).Seconds(), + oauth2.Config.GetRFC8628TokenPollingInterval(context.TODO()).Truncate(jwt.TimePrecision).Seconds(), )) resp.SetExpiresIn(int64( - time.Now().Round(time.Second).Add(oauth2.Config.GetRFC8628CodeLifespan(context.TODO())).Second(), + time.Now().Truncate(jwt.TimePrecision).Add(oauth2.Config.GetRFC8628CodeLifespan(context.TODO())).Second(), )) resp.SetVerificationURI(oauth2.Config.GetRFC8628UserVerificationURL(context.TODO())) resp.SetVerificationURIComplete( diff --git a/rfc8628_user_authorize_request_handler_test.go b/rfc8628_user_authorize_request_handler_test.go index 6fbb5242..4887178c 100644 --- a/rfc8628_user_authorize_request_handler_test.go +++ b/rfc8628_user_authorize_request_handler_test.go @@ -75,6 +75,5 @@ func TestFosite_NewRFC8628UserAuthorizeRequest(t *testing.T) { assert.NotNil(t, resp, "%d", k) assert.Equal(t, req.Form, resp.GetRequestForm()) } - t.Logf("Passed test case %d", k) } } diff --git a/rfc8628_user_authorize_response_writer_test.go b/rfc8628_user_authorize_response_writer_test.go index f17296bc..46861116 100644 --- a/rfc8628_user_authorize_response_writer_test.go +++ b/rfc8628_user_authorize_response_writer_test.go @@ -71,6 +71,5 @@ func TestFosite_NewRFC8628UserAuthorizeResponse(t *testing.T) { } else { assert.NotNil(t, responder, "%d", k) } - t.Logf("Passed test case %d", k) } } diff --git a/session.go b/session.go index ea691175..9b1faa9a 100644 --- a/session.go +++ b/session.go @@ -19,7 +19,7 @@ type Session interface { // GetExpiresAt returns the expiration time of a token if set, or time.IsZero() if not. // - // session.GetExpiresAt(oauth2.AccessToken) + // session.GetExpiresTimeX(oauth2.AccessToken) GetExpiresAt(key TokenType) time.Time // GetUsername returns the username, if set. This is optional and only used during token introspection. diff --git a/testing/mock/client.go b/testing/mock/client.go index 4d43f535..7c88ba6c 100644 --- a/testing/mock/client.go +++ b/testing/mock/client.go @@ -67,6 +67,22 @@ func (mr *MockClientMockRecorder) GetClientSecret() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientSecret", reflect.TypeOf((*MockClient)(nil).GetClientSecret)) } +// GetClientSecretPlainText mocks base method. +func (m *MockClient) GetClientSecretPlainText() ([]byte, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClientSecretPlainText") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetClientSecretPlainText indicates an expected call of GetClientSecretPlainText. +func (mr *MockClientMockRecorder) GetClientSecretPlainText() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientSecretPlainText", reflect.TypeOf((*MockClient)(nil).GetClientSecretPlainText)) +} + // GetGrantTypes mocks base method. func (m *MockClient) GetGrantTypes() oauth2.Arguments { m.ctrl.T.Helper() diff --git a/testing/mock/pkce_storage.go b/testing/mock/pkce_storage.go index f3a9014d..baf4fb4f 100644 --- a/testing/mock/pkce_storage.go +++ b/testing/mock/pkce_storage.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: authelia.com/provider/oauth2/handler/pkce (interfaces: PKCERequestStorage) +// Source: authelia.com/provider/oauth2/handler/pkce (interfaces: Storage) // // Generated by this command: // -// mockgen -package mock -destination testing/mock/pkce_storage.go authelia.com/provider/oauth2/handler/pkce PKCERequestStorage +// mockgen -package mock -destination testing/mock/pkce_storage.go -mock_names Storage=MockPKCERequestStorage authelia.com/provider/oauth2/handler/pkce Storage // // Package mock is a generated GoMock package. @@ -17,7 +17,7 @@ import ( gomock "go.uber.org/mock/gomock" ) -// MockPKCERequestStorage is a mock of PKCERequestStorage interface. +// MockPKCERequestStorage is a mock of Storage interface. type MockPKCERequestStorage struct { ctrl *gomock.Controller recorder *MockPKCERequestStorageMockRecorder diff --git a/token/hmac/hmacsha_test.go b/token/hmac/hmacsha_test.go index a064a2f1..63b8a0c8 100644 --- a/token/hmac/hmacsha_test.go +++ b/token/hmac/hmacsha_test.go @@ -65,7 +65,7 @@ func TestValidateSignatureRejects(t *testing.T) { cg := HMACStrategy{ Config: &oauth2.Config{GlobalSecret: []byte("1234567890123456789012345678901234567890")}, } - for k, c := range []string{ + for _, c := range []string{ "", " ", "foo.bar", @@ -74,7 +74,6 @@ func TestValidateSignatureRejects(t *testing.T) { } { err = cg.Validate(context.Background(), c) assert.Error(t, err) - t.Logf("Passed test case %d", k) } } diff --git a/token/jarm/generate.go b/token/jarm/generate.go index bb4e18bd..4801fc5e 100644 --- a/token/jarm/generate.go +++ b/token/jarm/generate.go @@ -4,9 +4,6 @@ import ( "context" "errors" "net/url" - "time" - - "github.com/google/uuid" "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/token/jwt" @@ -22,15 +19,15 @@ func EncodeParameters(token, _ string, tErr error) (parameters url.Values, err e } // Generate generates the token and signature for a JARM response. -func Generate(ctx context.Context, config Configurator, client Client, session any, in url.Values) (token, signature string, err error) { +func Generate(ctx context.Context, config Configurator, client Client, session any, parameters url.Values) (token, signature string, err error) { headers := map[string]any{} if alg := client.GetAuthorizationSignedResponseAlg(); len(alg) > 0 { - headers[consts.JSONWebTokenHeaderAlgorithm] = alg + headers[jwt.JSONWebTokenHeaderAlgorithm] = alg } if kid := client.GetAuthorizationSignedResponseKeyID(); len(kid) > 0 { - headers[consts.JSONWebTokenHeaderKeyIdentifier] = kid + headers[jwt.JSONWebTokenHeaderKeyIdentifier] = kid } var issuer string @@ -55,29 +52,22 @@ func Generate(ctx context.Context, config Configurator, client Client, session a return "", "", errors.New("The JARM response modes require the Authorize Requester session to implement either the openid.Session or oauth2.JWTSessionContainer interfaces but it doesn't.") } - if value, ok = src[consts.ClaimIssuer]; ok { + if value, ok = src[jwt.ClaimIssuer]; ok { issuer, _ = value.(string) } } - claims := &jwt.JARMClaims{ - JTI: uuid.New().String(), - Issuer: issuer, - IssuedAt: time.Now().UTC(), - ExpiresAt: time.Now().UTC().Add(config.GetJWTSecuredAuthorizeResponseModeLifespan(ctx)), - Audience: []string{client.GetID()}, - Extra: map[string]any{}, - } + claims := jwt.NewJARMClaims(issuer, jwt.ClaimStrings{client.GetID()}, config.GetJWTSecuredAuthorizeResponseModeLifespan(ctx)) - for param := range in { - claims.Extra[param] = in.Get(param) + for param := range parameters { + claims.Extra[param] = parameters.Get(param) } - var signer jwt.Signer + var strategy jwt.Strategy - if signer = config.GetJWTSecuredAuthorizeResponseModeSigner(ctx); signer == nil { - return "", "", errors.New("The JARM response modes require the JWTSecuredAuthorizeResponseModeSignerProvider to return a jwt.Signer but it didn't.") + if strategy = config.GetJWTSecuredAuthorizeResponseModeStrategy(ctx); strategy == nil { + return "", "", errors.New("The JARM response modes require the JWTSecuredAuthorizeResponseModeSignerProvider to return a jwt.Strategy but it didn't.") } - return signer.Generate(ctx, claims.ToMapClaims(), &jwt.Headers{Extra: headers}) + return strategy.Encode(ctx, claims.ToMapClaims(), jwt.WithHeaders(&jwt.Headers{Extra: headers}), jwt.WithJARMClient(client)) } diff --git a/token/jarm/types.go b/token/jarm/types.go index 2ad362da..4bb0fdae 100644 --- a/token/jarm/types.go +++ b/token/jarm/types.go @@ -9,7 +9,7 @@ import ( type Configurator interface { GetJWTSecuredAuthorizeResponseModeIssuer(ctx context.Context) string - GetJWTSecuredAuthorizeResponseModeSigner(ctx context.Context) jwt.Signer + GetJWTSecuredAuthorizeResponseModeStrategy(ctx context.Context) jwt.Strategy GetJWTSecuredAuthorizeResponseModeLifespan(ctx context.Context) time.Duration } @@ -17,6 +17,9 @@ type Client interface { // GetID returns the client ID. GetID() (id string) + // IsPublic returns true if the client has the public client type. + IsPublic() (public bool) + // GetAuthorizationSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements of the // JWT-secured Authorization Response Method (JARM) specifications. If unspecified the other available parameters // will be utilized to select an appropriate key. diff --git a/token/jwt/claims.go b/token/jwt/claims.go index 501e1ad3..36719a70 100644 --- a/token/jwt/claims.go +++ b/token/jwt/claims.go @@ -7,6 +7,17 @@ import ( "time" ) +type Claims interface { + GetExpirationTime() (exp *NumericDate, err error) + GetIssuedAt() (iat *NumericDate, err error) + GetNotBefore() (nbf *NumericDate, err error) + GetIssuer() (iss string, err error) + GetSubject() (sub string, err error) + GetAudience() (aud ClaimStrings, err error) + ToMapClaims() MapClaims + Valid(opts ...ClaimValidationOption) (err error) +} + // Mapper is the interface used internally to map key-value pairs type Mapper interface { ToMap() map[string]any diff --git a/token/jwt/claims_id_token.go b/token/jwt/claims_id_token.go index 862b9305..7bddfe33 100644 --- a/token/jwt/claims_id_token.go +++ b/token/jwt/claims_id_token.go @@ -4,11 +4,16 @@ package jwt import ( + "bytes" + "errors" + "fmt" "time" + jjson "github.com/go-jose/go-jose/v4/json" "github.com/google/uuid" "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/x/errorsx" ) // IDTokenClaims represent the claims used in open id connect requests @@ -17,109 +22,340 @@ type IDTokenClaims struct { Issuer string `json:"iss"` Subject string `json:"sub"` Audience []string `json:"aud"` - Nonce string `json:"nonce"` - ExpiresAt time.Time `json:"exp"` - IssuedAt time.Time `json:"iat"` - RequestedAt time.Time `json:"rat"` - AuthTime time.Time `json:"auth_time"` - AccessTokenHash string `json:"at_hash"` - AuthenticationContextClassReference string `json:"acr"` - AuthenticationMethodsReferences []string `json:"amr"` - CodeHash string `json:"c_hash"` - StateHash string `json:"s_hash"` - Extra map[string]any `json:"ext"` + ExpirationTime *NumericDate `json:"exp"` + IssuedAt *NumericDate `json:"iat"` + AuthTime *NumericDate `json:"auth_time,omitempty"` + RequestedAt *NumericDate `json:"rat,omitempty"` + Nonce string `json:"nonce,omitempty"` + AuthenticationContextClassReference string `json:"acr,omitempty"` + AuthenticationMethodsReferences []string `json:"amr,omitempty"` + AuthorizedParty string `json:"azp,omitempty"` + AccessTokenHash string `json:"at_hash,omitempty"` + CodeHash string `json:"c_hash,omitempty"` + StateHash string `json:"s_hash,omitempty"` + Extra map[string]any `json:"ext,omitempty"` +} + +func (c *IDTokenClaims) GetExpirationTime() (exp *NumericDate, err error) { + return c.ExpirationTime, nil +} + +func (c *IDTokenClaims) GetIssuedAt() (iat *NumericDate, err error) { + return c.IssuedAt, nil +} + +func (c *IDTokenClaims) GetNotBefore() (nbf *NumericDate, err error) { + return toNumericDate(ClaimNotBefore) +} + +func (c *IDTokenClaims) GetIssuer() (iss string, err error) { + return c.Issuer, nil +} + +func (c *IDTokenClaims) GetSubject() (sub string, err error) { + return c.Subject, nil +} + +func (c *IDTokenClaims) GetAudience() (aud ClaimStrings, err error) { + return c.Audience, nil +} + +func (c IDTokenClaims) Valid(opts ...ClaimValidationOption) (err error) { + vopts := &ClaimValidationOptions{} + + for _, opt := range opts { + opt(vopts) + } + + var now int64 + + if vopts.timef != nil { + now = vopts.timef().UTC().Unix() + } else { + now = TimeFunc().UTC().Unix() + } + + vErr := new(ValidationError) + + var date *NumericDate + + if date, err = c.GetExpirationTime(); !validDate(validInt64Future, now, vopts.expRequired, date, err) { + vErr.Inner = errors.New("Token is expired") + vErr.Errors |= ValidationErrorExpired + } + + if date, err = c.GetIssuedAt(); !validDate(validInt64Past, now, vopts.expRequired, date, err) { + vErr.Inner = errors.New("Token used before issued") + vErr.Errors |= ValidationErrorIssuedAt + } + + if date, err = c.GetNotBefore(); !validDate(validInt64Past, now, vopts.expRequired, date, err) { + vErr.Inner = errors.New("Token is not valid yet") + vErr.Errors |= ValidationErrorNotValidYet + } + + var str string + + if len(vopts.iss) != 0 { + if str, err = c.GetIssuer(); err != nil { + vErr.Inner = errors.New("Token has invalid issuer") + vErr.Errors |= ValidationErrorIssuer + } else if !validString(str, vopts.iss, true) { + vErr.Inner = errors.New("Token has invalid issuer") + vErr.Errors |= ValidationErrorIssuer + } + } + + if len(vopts.sub) != 0 { + if str, err = c.GetSubject(); err != nil { + vErr.Inner = errors.New("Token has invalid subject") + vErr.Errors |= ValidationErrorIssuer + } else if !validString(str, vopts.sub, true) { + vErr.Inner = errors.New("Token has invalid subject") + vErr.Errors |= ValidationErrorSubject + } + } + + var aud ClaimStrings + + if len(vopts.aud) != 0 { + if aud, err = c.GetAudience(); err != nil || aud == nil || !aud.ValidAny(vopts.aud, true) { + vErr.Inner = errors.New("Token has invalid audience") + vErr.Errors |= ValidationErrorAudience + } + } + + if len(vopts.audAll) != 0 { + if aud, err = c.GetAudience(); err != nil || aud == nil || !aud.ValidAll(vopts.audAll, true) { + vErr.Inner = errors.New("Token has invalid audience") + vErr.Errors |= ValidationErrorAudience + } + } + + if vErr.valid() { + return nil + } + + return vErr +} + +func (c *IDTokenClaims) GetExpirationTimeSafe() time.Time { + if c.ExpirationTime == nil { + return time.Unix(0, 0).UTC() + } + + return c.ExpirationTime.UTC() +} + +func (c *IDTokenClaims) GetIssuedAtSafe() time.Time { + if c.IssuedAt == nil { + return time.Unix(0, 0).UTC() + } + + return c.IssuedAt.UTC() +} + +func (c *IDTokenClaims) GetAuthTimeSafe() time.Time { + if c.AuthTime == nil { + return time.Unix(0, 0).UTC() + } + + return c.AuthTime.UTC() +} + +func (c *IDTokenClaims) GetRequestedAtSafe() time.Time { + if c.RequestedAt == nil { + return time.Unix(0, 0).UTC() + } + + return c.RequestedAt.UTC() +} + +func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { + claims := MapClaims{} + + decoder := jjson.NewDecoder(bytes.NewReader(data)) + decoder.SetNumberType(jjson.UnmarshalIntOrFloat) + + if err := decoder.Decode(&claims); err != nil { + return errorsx.WithStack(err) + } + + var ( + ok bool + err error + ) + + for claim, value := range claims { + ok = false + + switch claim { + case ClaimJWTID: + c.JTI, ok = value.(string) + case ClaimIssuer: + c.Issuer, ok = value.(string) + case ClaimSubject: + c.Subject, ok = value.(string) + case ClaimAudience: + c.Audience, ok = toStringSlice(value) + case ClaimExpirationTime: + if c.ExpirationTime, err = toNumericDate(value); err == nil { + ok = true + } + case ClaimIssuedAt: + if c.IssuedAt, err = toNumericDate(value); err == nil { + ok = true + } + case ClaimAuthenticationTime: + if c.AuthTime, err = toNumericDate(value); err == nil { + ok = true + } + case ClaimRequestedAt: + if c.RequestedAt, err = toNumericDate(value); err == nil { + ok = true + } + case ClaimNonce: + c.Nonce, ok = value.(string) + case ClaimAuthenticationContextClassReference: + c.AuthenticationContextClassReference, ok = value.(string) + case ClaimAuthenticationMethodsReference: + c.AuthenticationMethodsReferences, ok = toStringSlice(value) + case ClaimAuthorizedParty: + c.AuthorizedParty, ok = value.(string) + case ClaimAccessTokenHash: + c.AccessTokenHash, ok = value.(string) + case ClaimCodeHash: + c.CodeHash, ok = value.(string) + case ClaimStateHash: + c.StateHash, ok = value.(string) + case ClaimExtra: + c.Extra, ok = value.(map[string]any) + default: + if c.Extra == nil { + c.Extra = make(map[string]any) + } + + c.Extra[claim] = value + + continue + } + + if !ok { + return fmt.Errorf("claim %s with value %v could not be decoded", claim, value) + } + } + + return nil } // ToMap will transform the headers to a map structure func (c *IDTokenClaims) ToMap() map[string]any { var ret = Copy(c.Extra) - if c.Subject != "" { - ret[consts.ClaimSubject] = c.Subject + if c.JTI != "" { + ret[ClaimJWTID] = c.JTI } else { - delete(ret, consts.ClaimSubject) + ret[ClaimJWTID] = uuid.New().String() } if c.Issuer != "" { - ret[consts.ClaimIssuer] = c.Issuer + ret[ClaimIssuer] = c.Issuer } else { delete(ret, consts.ClaimIssuer) } - if c.JTI != "" { - ret[consts.ClaimJWTID] = c.JTI + if c.Subject != "" { + ret[ClaimSubject] = c.Subject } else { - ret[consts.ClaimJWTID] = uuid.New().String() + delete(ret, ClaimSubject) } if len(c.Audience) > 0 { - ret[consts.ClaimAudience] = c.Audience + ret[ClaimAudience] = c.Audience + } else { + delete(ret, ClaimAudience) + } + + if c.ExpirationTime != nil { + ret[ClaimExpirationTime] = c.ExpirationTime.Unix() } else { - ret[consts.ClaimAudience] = []string{} + delete(ret, ClaimExpirationTime) } - if !c.IssuedAt.IsZero() { - ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() + if c.IssuedAt != nil { + ret[ClaimIssuedAt] = c.IssuedAt.Unix() } else { - delete(ret, consts.ClaimIssuedAt) + delete(ret, ClaimIssuedAt) } - if !c.ExpiresAt.IsZero() { - ret[consts.ClaimExpirationTime] = c.ExpiresAt.Unix() + if c.AuthTime != nil { + ret[ClaimAuthenticationTime] = c.AuthTime.Unix() } else { - delete(ret, consts.ClaimExpirationTime) + delete(ret, ClaimAuthenticationTime) + } + + if c.RequestedAt != nil { + ret[ClaimRequestedAt] = c.RequestedAt.Unix() + } else { + delete(ret, ClaimRequestedAt) } if len(c.Nonce) > 0 { - ret[consts.ClaimNonce] = c.Nonce + ret[ClaimNonce] = c.Nonce } else { - delete(ret, consts.ClaimNonce) + delete(ret, ClaimNonce) } - if len(c.AccessTokenHash) > 0 { - ret[consts.ClaimAccessTokenHash] = c.AccessTokenHash + if len(c.AuthenticationContextClassReference) > 0 { + ret[ClaimAuthenticationContextClassReference] = c.AuthenticationContextClassReference } else { - delete(ret, consts.ClaimAccessTokenHash) + delete(ret, ClaimAuthenticationContextClassReference) } - if len(c.CodeHash) > 0 { - ret[consts.ClaimCodeHash] = c.CodeHash + if len(c.AuthenticationMethodsReferences) > 0 { + ret[ClaimAuthenticationMethodsReference] = c.AuthenticationMethodsReferences } else { - delete(ret, consts.ClaimCodeHash) + delete(ret, ClaimAuthenticationMethodsReference) } - if len(c.StateHash) > 0 { - ret[consts.ClaimStateHash] = c.StateHash + if len(c.AuthorizedParty) > 0 { + ret[ClaimAuthorizedParty] = c.AuthorizedParty } else { - delete(ret, consts.ClaimStateHash) + delete(ret, ClaimAuthorizedParty) } - if !c.AuthTime.IsZero() { - ret[consts.ClaimAuthenticationTime] = c.AuthTime.Unix() + if len(c.AccessTokenHash) > 0 { + ret[ClaimAccessTokenHash] = c.AccessTokenHash } else { - delete(ret, consts.ClaimAuthenticationTime) + delete(ret, ClaimAccessTokenHash) } - if len(c.AuthenticationContextClassReference) > 0 { - ret[consts.ClaimAuthenticationContextClassReference] = c.AuthenticationContextClassReference + if len(c.CodeHash) > 0 { + ret[ClaimCodeHash] = c.CodeHash } else { - delete(ret, consts.ClaimAuthenticationContextClassReference) + delete(ret, ClaimCodeHash) } - if len(c.AuthenticationMethodsReferences) > 0 { - ret[consts.ClaimAuthenticationMethodsReference] = c.AuthenticationMethodsReferences + if len(c.StateHash) > 0 { + ret[ClaimStateHash] = c.StateHash } else { - delete(ret, consts.ClaimAuthenticationMethodsReference) + delete(ret, ClaimStateHash) } return ret } +// ToMapClaims will return a jwt-go MapClaims representation +func (c IDTokenClaims) ToMapClaims() MapClaims { + return c.ToMap() +} + // Add will add a key-value pair to the extra field func (c *IDTokenClaims) Add(key string, value any) { if c.Extra == nil { c.Extra = make(map[string]any) } + c.Extra[key] = value } @@ -128,7 +364,50 @@ func (c *IDTokenClaims) Get(key string) any { return c.ToMap()[key] } -// ToMapClaims will return a jwt-go MapClaims representation -func (c IDTokenClaims) ToMapClaims() MapClaims { - return c.ToMap() +func (c IDTokenClaims) toNumericDate(key string) (date *NumericDate, err error) { + var ( + v any + ok bool + ) + + if v, ok = c.Extra[key]; !ok { + return nil, nil + } + + return toNumericDate(v) +} + +func toStringSlice(value any) (values []string, ok bool) { + switch t := value.(type) { + case nil: + ok = true + case string: + ok = true + + values = []string{t} + case []string: + ok = true + + values = t + case []any: + ok = true + + loop: + for _, tv := range t { + switch vv := tv.(type) { + case string: + values = append(values, vv) + default: + ok = false + + break loop + } + } + } + + return values, ok } + +var ( + _ Claims = (*IDTokenClaims)(nil) +) diff --git a/token/jwt/claims_id_token_test.go b/token/jwt/claims_id_token_test.go index 0fecf56f..8c22119c 100644 --- a/token/jwt/claims_id_token_test.go +++ b/token/jwt/claims_id_token_test.go @@ -14,24 +14,24 @@ import ( ) func TestIDTokenAssert(t *testing.T) { - assert.NoError(t, (&IDTokenClaims{ExpiresAt: time.Now().UTC().Add(time.Hour)}). + assert.NoError(t, (&IDTokenClaims{ExpirationTime: NewNumericDate(time.Now().Add(time.Hour))}). ToMapClaims().Valid()) - assert.Error(t, (&IDTokenClaims{ExpiresAt: time.Now().UTC().Add(-time.Hour)}). + assert.Error(t, (&IDTokenClaims{ExpirationTime: NewNumericDate(time.Now().Add(-time.Hour))}). ToMapClaims().Valid()) - assert.NotEmpty(t, (new(IDTokenClaims)).ToMapClaims()[consts.ClaimJWTID]) + assert.NotEmpty(t, (new(IDTokenClaims)).ToMapClaims()[ClaimJWTID]) } func TestIDTokenClaimsToMap(t *testing.T) { idTokenClaims := &IDTokenClaims{ JTI: "foo-id", Subject: "peter", - IssuedAt: time.Now().UTC().Round(time.Second), + IssuedAt: Now(), Issuer: "authelia", Audience: []string{"tests"}, - ExpiresAt: time.Now().UTC().Add(time.Hour).Round(time.Second), - AuthTime: time.Now().UTC(), - RequestedAt: time.Now().UTC(), + ExpirationTime: NewNumericDate(time.Now().Add(time.Hour)), + AuthTime: Now(), + RequestedAt: Now(), AccessTokenHash: "foobar", CodeHash: "barfoo", StateHash: "boofar", @@ -43,18 +43,19 @@ func TestIDTokenClaimsToMap(t *testing.T) { }, } assert.Equal(t, map[string]any{ - consts.ClaimJWTID: idTokenClaims.JTI, - consts.ClaimSubject: idTokenClaims.Subject, - consts.ClaimIssuedAt: idTokenClaims.IssuedAt.Unix(), - consts.ClaimIssuer: idTokenClaims.Issuer, - consts.ClaimAudience: idTokenClaims.Audience, - consts.ClaimExpirationTime: idTokenClaims.ExpiresAt.Unix(), - "foo": idTokenClaims.Extra["foo"], - "baz": idTokenClaims.Extra["baz"], - consts.ClaimAccessTokenHash: idTokenClaims.AccessTokenHash, - consts.ClaimCodeHash: idTokenClaims.CodeHash, - consts.ClaimStateHash: idTokenClaims.StateHash, - consts.ClaimAuthenticationTime: idTokenClaims.AuthTime.Unix(), + ClaimJWTID: idTokenClaims.JTI, + ClaimSubject: idTokenClaims.Subject, + ClaimIssuedAt: idTokenClaims.IssuedAt.Unix(), + ClaimIssuer: idTokenClaims.Issuer, + ClaimAudience: idTokenClaims.Audience, + ClaimExpirationTime: idTokenClaims.ExpirationTime.Unix(), + ClaimRequestedAt: idTokenClaims.RequestedAt.Unix(), + "foo": idTokenClaims.Extra["foo"], + "baz": idTokenClaims.Extra["baz"], + ClaimAccessTokenHash: idTokenClaims.AccessTokenHash, + ClaimCodeHash: idTokenClaims.CodeHash, + ClaimStateHash: idTokenClaims.StateHash, + ClaimAuthenticationTime: idTokenClaims.AuthTime.Unix(), consts.ClaimAuthenticationContextClassReference: idTokenClaims.AuthenticationContextClassReference, consts.ClaimAuthenticationMethodsReference: idTokenClaims.AuthenticationMethodsReferences, }, idTokenClaims.ToMap()) @@ -66,7 +67,8 @@ func TestIDTokenClaimsToMap(t *testing.T) { consts.ClaimIssuedAt: idTokenClaims.IssuedAt.Unix(), consts.ClaimIssuer: idTokenClaims.Issuer, consts.ClaimAudience: idTokenClaims.Audience, - consts.ClaimExpirationTime: idTokenClaims.ExpiresAt.Unix(), + consts.ClaimExpirationTime: idTokenClaims.ExpirationTime.Unix(), + consts.ClaimRequestedAt: idTokenClaims.RequestedAt.Unix(), "foo": idTokenClaims.Extra["foo"], "baz": idTokenClaims.Extra["baz"], consts.ClaimAccessTokenHash: idTokenClaims.AccessTokenHash, diff --git a/token/jwt/claims_jarm.go b/token/jwt/claims_jarm.go index d43c5a98..dd0a4773 100644 --- a/token/jwt/claims_jarm.go +++ b/token/jwt/claims_jarm.go @@ -1,21 +1,61 @@ package jwt import ( + "fmt" "time" "github.com/google/uuid" - - "authelia.com/provider/oauth2/internal/consts" ) +func NewJARMClaims(issuer string, aud ClaimStrings, lifespan time.Duration) *JARMClaims { + now := time.Now() + + return &JARMClaims{ + Issuer: issuer, + Audience: aud, + JTI: uuid.NewString(), + IssuedAt: NewNumericDate(now), + ExpirationTime: NewNumericDate(now.Add(lifespan)), + Extra: map[string]any{}, + } +} + // JARMClaims represent a token's claims. type JARMClaims struct { - Issuer string - Audience []string - JTI string - IssuedAt time.Time - ExpiresAt time.Time - Extra map[string]any + Issuer string `json:"iss"` + Audience ClaimStrings `json:"aud"` + JTI string `json:"jti"` + IssuedAt *NumericDate `json:"iat,omitempty"` + ExpirationTime *NumericDate `json:"exp,omitempty"` + Extra map[string]any `json:"-"` +} + +func (c *JARMClaims) GetExpirationTime() (exp *NumericDate, err error) { + return c.ExpirationTime, nil +} + +func (c *JARMClaims) GetIssuedAt() (iat *NumericDate, err error) { + return c.IssuedAt, nil +} + +func (c *JARMClaims) GetNotBefore() (nbf *NumericDate, err error) { + return c.toNumericDate(ClaimNotBefore) +} + +func (c *JARMClaims) GetIssuer() (iss string, err error) { + return c.Issuer, nil +} + +func (c *JARMClaims) GetSubject() (sub string, err error) { + return c.toString(ClaimIssuer) +} + +func (c *JARMClaims) GetAudience() (aud ClaimStrings, err error) { + return c.Audience, nil +} + +func (c *JARMClaims) Valid(opts ...ClaimValidationOption) (err error) { + return nil } // ToMap will transform the headers to a map structure @@ -23,65 +63,75 @@ func (c *JARMClaims) ToMap() map[string]any { var ret = Copy(c.Extra) if c.Issuer != "" { - ret[consts.ClaimIssuer] = c.Issuer + ret[ClaimIssuer] = c.Issuer } else { - delete(ret, consts.ClaimIssuer) + delete(ret, ClaimIssuer) } if c.JTI != "" { - ret[consts.ClaimJWTID] = c.JTI + ret[ClaimJWTID] = c.JTI } else { - ret[consts.ClaimJWTID] = uuid.New().String() + ret[ClaimJWTID] = uuid.New().String() } if len(c.Audience) > 0 { - ret[consts.ClaimAudience] = c.Audience + ret[ClaimAudience] = []string(c.Audience) } else { - ret[consts.ClaimAudience] = []string{} + ret[ClaimAudience] = []string{} } - if !c.IssuedAt.IsZero() { - ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() + if c.IssuedAt != nil { + ret[ClaimIssuedAt] = c.IssuedAt.Unix() } else { - delete(ret, consts.ClaimIssuedAt) + delete(ret, ClaimIssuedAt) } - if !c.ExpiresAt.IsZero() { - ret[consts.ClaimExpirationTime] = c.ExpiresAt.Unix() + if c.ExpirationTime != nil { + ret[ClaimExpirationTime] = c.ExpirationTime.Unix() } else { - delete(ret, consts.ClaimExpirationTime) + delete(ret, ClaimExpirationTime) } return ret } +// ToMapClaims will return a jwt-go MapClaims representation +func (c JARMClaims) ToMapClaims() MapClaims { + return c.ToMap() +} + // FromMap will set the claims based on a mapping func (c *JARMClaims) FromMap(m map[string]any) { c.Extra = make(map[string]any) for k, v := range m { switch k { - case consts.ClaimIssuer: + case ClaimIssuer: if s, ok := v.(string); ok { c.Issuer = s } - case consts.ClaimJWTID: + case ClaimJWTID: if s, ok := v.(string); ok { c.JTI = s } - case consts.ClaimAudience: + case ClaimAudience: if aud, ok := StringSliceFromMap(v); ok { c.Audience = aud } - case consts.ClaimIssuedAt: - c.IssuedAt = toTime(v, c.IssuedAt) - case consts.ClaimExpirationTime: - c.ExpiresAt = toTime(v, c.ExpiresAt) + case ClaimIssuedAt: + c.IssuedAt, _ = toNumericDate(v) + case ClaimExpirationTime: + c.ExpirationTime, _ = toNumericDate(v) default: c.Extra[k] = v } } } +// FromMapClaims will populate claims from a jwt-go MapClaims representation +func (c *JARMClaims) FromMapClaims(mc MapClaims) { + c.FromMap(mc) +} + // Add will add a key-value pair to the extra field func (c *JARMClaims) Add(key string, value any) { if c.Extra == nil { @@ -96,12 +146,36 @@ func (c JARMClaims) Get(key string) any { return c.ToMap()[key] } -// ToMapClaims will return a jwt-go MapClaims representation -func (c JARMClaims) ToMapClaims() MapClaims { - return c.ToMap() +func (c JARMClaims) toNumericDate(key string) (date *NumericDate, err error) { + var ( + v any + ok bool + ) + + if v, ok = c.Extra[key]; !ok { + return nil, nil + } + + return toNumericDate(v) } -// FromMapClaims will populate claims from a jwt-go MapClaims representation -func (c *JARMClaims) FromMapClaims(mc MapClaims) { - c.FromMap(mc) +func (c JARMClaims) toString(key string) (value string, err error) { + var ( + ok bool + raw any + ) + + if raw, ok = c.Extra[key]; !ok { + return "", nil + } + + if value, ok = raw.(string); !ok { + return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) + } + + return value, nil } + +var ( + _ Claims = (*JARMClaims)(nil) +) diff --git a/token/jwt/claims_jarm_test.go b/token/jwt/claims_jarm_test.go index 941236d7..f114177e 100644 --- a/token/jwt/claims_jarm_test.go +++ b/token/jwt/claims_jarm_test.go @@ -9,16 +9,15 @@ import ( "github.com/stretchr/testify/assert" - "authelia.com/provider/oauth2/internal/consts" . "authelia.com/provider/oauth2/token/jwt" ) var jarmClaims = &JARMClaims{ - Issuer: "authelia", - Audience: []string{"tests"}, - JTI: "abcdef", - IssuedAt: time.Now().UTC().Round(time.Second), - ExpiresAt: time.Now().UTC().Add(time.Hour).Round(time.Second), + Issuer: "authelia", + Audience: []string{"tests"}, + JTI: "abcdef", + IssuedAt: Now(), + ExpirationTime: NewNumericDate(time.Now().Add(time.Hour)), Extra: map[string]any{ "foo": "bar", "baz": "bar", @@ -26,13 +25,13 @@ var jarmClaims = &JARMClaims{ } var jarmClaimsMap = map[string]any{ - consts.ClaimIssuer: jwtClaims.Issuer, - consts.ClaimAudience: jwtClaims.Audience, - consts.ClaimJWTID: jwtClaims.JTI, - consts.ClaimIssuedAt: jwtClaims.IssuedAt.Unix(), - consts.ClaimExpirationTime: jwtClaims.ExpiresAt.Unix(), - "foo": jwtClaims.Extra["foo"], - "baz": jwtClaims.Extra["baz"], + ClaimIssuer: jwtClaims.Issuer, + ClaimAudience: jwtClaims.Audience, + ClaimJWTID: jwtClaims.JTI, + ClaimIssuedAt: jwtClaims.IssuedAt.Unix(), + ClaimExpirationTime: jwtClaims.ExpiresAt.Unix(), + "foo": jwtClaims.Extra["foo"], + "baz": jwtClaims.Extra["baz"], } func TestJARMClaimAddGetString(t *testing.T) { @@ -41,13 +40,13 @@ func TestJARMClaimAddGetString(t *testing.T) { } func TestJARMClaimsToMapSetsID(t *testing.T) { - assert.NotEmpty(t, (&JARMClaims{}).ToMap()[consts.ClaimJWTID]) + assert.NotEmpty(t, (&JARMClaims{}).ToMap()[ClaimJWTID]) } func TestJARMAssert(t *testing.T) { - assert.Nil(t, (&JARMClaims{ExpiresAt: time.Now().UTC().Add(time.Hour)}). + assert.Nil(t, (&JARMClaims{ExpirationTime: NewNumericDate(time.Now().Add(time.Hour))}). ToMapClaims().Valid()) - assert.NotNil(t, (&JARMClaims{ExpiresAt: time.Now().UTC().Add(-2 * time.Hour)}). + assert.NotNil(t, (&JARMClaims{ExpirationTime: NewNumericDate(time.Now().Add(-2 * time.Hour))}). ToMapClaims().Valid()) } diff --git a/token/jwt/claims_jwt.go b/token/jwt/claims_jwt.go index 0ae784de..db1bcf93 100644 --- a/token/jwt/claims_jwt.go +++ b/token/jwt/claims_jwt.go @@ -4,6 +4,7 @@ package jwt import ( + "encoding/json" "strings" "time" @@ -43,7 +44,7 @@ type JWTClaimsContainer interface { // WithScopeField configures how a scope field should be represented in JWT. WithScopeField(scopeField JWTScopeFieldEnum) JWTClaimsContainer - // ToMapClaims returns the claims as a github.com/dgrijalva/jwt-go.MapClaims type. + // ToMapClaims returns the claims as a MapClaims type. ToMapClaims() MapClaims } @@ -100,58 +101,58 @@ func (c *JWTClaims) ToMap() map[string]any { var ret = Copy(c.Extra) if c.Subject != "" { - ret[consts.ClaimSubject] = c.Subject + ret[ClaimSubject] = c.Subject } else { - delete(ret, consts.ClaimSubject) + delete(ret, ClaimSubject) } if c.Issuer != "" { - ret[consts.ClaimIssuer] = c.Issuer + ret[ClaimIssuer] = c.Issuer } else { - delete(ret, consts.ClaimIssuer) + delete(ret, ClaimIssuer) } if c.JTI != "" { - ret[consts.ClaimJWTID] = c.JTI + ret[ClaimJWTID] = c.JTI } else { - ret[consts.ClaimJWTID] = uuid.New().String() + ret[ClaimJWTID] = uuid.New().String() } if len(c.Audience) > 0 { - ret[consts.ClaimAudience] = c.Audience + ret[ClaimAudience] = c.Audience } else { - ret[consts.ClaimAudience] = []string{} + ret[ClaimAudience] = []string{} } if !c.IssuedAt.IsZero() { - ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() + ret[ClaimIssuedAt] = c.IssuedAt.Unix() } else { - delete(ret, consts.ClaimIssuedAt) + delete(ret, ClaimIssuedAt) } if !c.NotBefore.IsZero() { - ret[consts.ClaimNotBefore] = c.NotBefore.Unix() + ret[ClaimNotBefore] = c.NotBefore.Unix() } else { - delete(ret, consts.ClaimNotBefore) + delete(ret, ClaimNotBefore) } if !c.ExpiresAt.IsZero() { - ret[consts.ClaimExpirationTime] = c.ExpiresAt.Unix() + ret[ClaimExpirationTime] = c.ExpiresAt.Unix() } else { - delete(ret, consts.ClaimExpirationTime) + delete(ret, ClaimExpirationTime) } if c.Scope != nil { // ScopeField default (when value is JWTScopeFieldUnset) is the list for backwards compatibility with old versions of oauth2. if c.ScopeField == JWTScopeFieldUnset || c.ScopeField == JWTScopeFieldList || c.ScopeField == JWTScopeFieldBoth { - ret[consts.ClaimScopeNonStandard] = c.Scope + ret[ClaimScopeNonStandard] = c.Scope } if c.ScopeField == JWTScopeFieldString || c.ScopeField == JWTScopeFieldBoth { - ret[consts.ClaimScope] = strings.Join(c.Scope, " ") + ret[ClaimScope] = strings.Join(c.Scope, " ") } } else { - delete(ret, consts.ClaimScopeNonStandard) - delete(ret, consts.ClaimScope) + delete(ret, ClaimScopeNonStandard) + delete(ret, ClaimScope) } return ret @@ -183,11 +184,11 @@ func (c *JWTClaims) FromMap(m map[string]any) { c.Audience = aud } case consts.ClaimIssuedAt: - c.IssuedAt = toTime(v, c.IssuedAt) + c.IssuedAt, _ = toTime(v, c.IssuedAt) case consts.ClaimNotBefore: - c.NotBefore = toTime(v, c.NotBefore) + c.NotBefore, _ = toTime(v, c.NotBefore) case consts.ClaimExpirationTime: - c.ExpiresAt = toTime(v, c.ExpiresAt) + c.ExpiresAt, _ = toTime(v, c.ExpiresAt) case consts.ClaimScopeNonStandard: switch s := v.(type) { case []string: @@ -225,15 +226,82 @@ func (c *JWTClaims) FromMap(m map[string]any) { } } -func toTime(v any, def time.Time) (t time.Time) { +func toTime(v any, def time.Time) (t time.Time, ok bool) { t = def - switch a := v.(type) { + + var value int64 + + if value, ok = toInt64(v); ok { + t = time.Unix(value, 0).UTC() + } + + return +} + +func toInt64(v any) (val int64, ok bool) { + var err error + + switch t := v.(type) { case float64: - t = time.Unix(int64(a), 0).UTC() + return int64(t), true case int64: - t = time.Unix(a, 0).UTC() + return t, true + case int32: + return int64(t), true + case int: + return int64(t), true + case json.Number: + if val, err = t.Int64(); err == nil { + return val, true + } + + var valf float64 + + if valf, err = t.Float64(); err != nil { + return 0, false + } + + return int64(valf), true } - return + + return 0, false +} + +func toNumericDate(v any) (date *NumericDate, err error) { + switch value := v.(type) { + case nil: + return nil, nil + case float64: + if value == 0 { + return nil, nil + } + + return newNumericDateFromSeconds(value), nil + case int64: + if value == 0 { + return nil, nil + } + + return newNumericDateFromSeconds(float64(value)), nil + case int32: + if value == 0 { + return nil, nil + } + + return newNumericDateFromSeconds(float64(value)), nil + case int: + if value == 0 { + return nil, nil + } + + return newNumericDateFromSeconds(float64(value)), nil + case json.Number: + vv, _ := value.Float64() + + return newNumericDateFromSeconds(vv), nil + } + + return nil, newError("value has invalid type", ErrInvalidType) } // Add will add a key-value pair to the extra field diff --git a/token/jwt/claims_jwt_test.go b/token/jwt/claims_jwt_test.go index 3b612140..7cc2ee07 100644 --- a/token/jwt/claims_jwt_test.go +++ b/token/jwt/claims_jwt_test.go @@ -15,11 +15,11 @@ import ( var jwtClaims = &JWTClaims{ Subject: "peter", - IssuedAt: time.Now().UTC().Round(time.Second), + IssuedAt: time.Now().UTC().Truncate(TimePrecision), Issuer: "authelia", - NotBefore: time.Now().UTC().Round(time.Second), + NotBefore: time.Now().UTC().Truncate(TimePrecision), Audience: []string{"tests"}, - ExpiresAt: time.Now().UTC().Add(time.Hour).Round(time.Second), + ExpiresAt: time.Now().UTC().Add(time.Hour).Truncate(TimePrecision), JTI: "abcdef", Scope: []string{consts.ScopeEmail, consts.ScopeOffline}, Extra: map[string]any{ @@ -30,16 +30,16 @@ var jwtClaims = &JWTClaims{ } var jwtClaimsMap = map[string]any{ - consts.ClaimSubject: jwtClaims.Subject, - consts.ClaimIssuedAt: jwtClaims.IssuedAt.Unix(), - consts.ClaimIssuer: jwtClaims.Issuer, - consts.ClaimNotBefore: jwtClaims.NotBefore.Unix(), - consts.ClaimAudience: jwtClaims.Audience, - consts.ClaimExpirationTime: jwtClaims.ExpiresAt.Unix(), - consts.ClaimJWTID: jwtClaims.JTI, - consts.ClaimScopeNonStandard: []string{consts.ScopeEmail, consts.ScopeOffline}, - "foo": jwtClaims.Extra["foo"], - "baz": jwtClaims.Extra["baz"], + ClaimSubject: jwtClaims.Subject, + ClaimIssuedAt: jwtClaims.IssuedAt.Unix(), + ClaimIssuer: jwtClaims.Issuer, + ClaimNotBefore: jwtClaims.NotBefore.Unix(), + ClaimAudience: jwtClaims.Audience, + ClaimExpirationTime: jwtClaims.ExpiresAt.Unix(), + ClaimJWTID: jwtClaims.JTI, + ClaimScopeNonStandard: []string{consts.ScopeEmail, consts.ScopeOffline}, + "foo": jwtClaims.Extra["foo"], + "baz": jwtClaims.Extra["baz"], } func TestClaimAddGetString(t *testing.T) { @@ -48,7 +48,7 @@ func TestClaimAddGetString(t *testing.T) { } func TestClaimsToMapSetsID(t *testing.T) { - assert.NotEmpty(t, (&JWTClaims{}).ToMap()[consts.ClaimJWTID]) + assert.NotEmpty(t, (&JWTClaims{}).ToMap()[ClaimJWTID]) } func TestAssert(t *testing.T) { @@ -78,8 +78,8 @@ func TestScopeFieldString(t *testing.T) { jwtClaimsWithString := jwtClaims.WithScopeField(JWTScopeFieldString) // Making a copy of jwtClaimsMap. jwtClaimsMapWithString := jwtClaims.ToMap() - delete(jwtClaimsMapWithString, consts.ClaimScopeNonStandard) - jwtClaimsMapWithString[consts.ClaimScope] = "email offline" + delete(jwtClaimsMapWithString, ClaimScopeNonStandard) + jwtClaimsMapWithString[ClaimScope] = "email offline" assert.Equal(t, jwtClaimsMapWithString, map[string]any(jwtClaimsWithString.ToMapClaims())) var claims JWTClaims claims.FromMap(jwtClaimsMapWithString) @@ -90,7 +90,7 @@ func TestScopeFieldBoth(t *testing.T) { jwtClaimsWithBoth := jwtClaims.WithScopeField(JWTScopeFieldBoth) // Making a copy of jwtClaimsMap jwtClaimsMapWithBoth := jwtClaims.ToMap() - jwtClaimsMapWithBoth[consts.ClaimScope] = "email offline" + jwtClaimsMapWithBoth[ClaimScope] = "email offline" assert.Equal(t, jwtClaimsMapWithBoth, map[string]any(jwtClaimsWithBoth.ToMapClaims())) var claims JWTClaims claims.FromMap(jwtClaimsMapWithBoth) diff --git a/token/jwt/claims_map.go b/token/jwt/claims_map.go index 85ed22de..113f804f 100644 --- a/token/jwt/claims_map.go +++ b/token/jwt/claims_map.go @@ -5,122 +5,281 @@ package jwt import ( "bytes" - "crypto/subtle" - "encoding/json" "errors" - "time" + "fmt" jjson "github.com/go-jose/go-jose/v4/json" - "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/x/errorsx" ) -var TimeFunc = time.Now +// NewMapClaims returns a set of MapClaims from an object that has the appropriate JSON tags. +func NewMapClaims(obj any) (claims MapClaims) { + return toMap(obj) +} -// MapClaims provides backwards compatible validations not available in `go-jose`. -// It was taken from [here](https://raw.githubusercontent.com/form3tech-oss/jwt-go/master/map_claims.go). -// -// Claims type that uses the map[string]any for JSON decoding -// This is the default claims type if you don't supply one +// MapClaims is a simple map based claims structure. type MapClaims map[string]any +// GetIssuer returns the 'iss' claim. +func (m MapClaims) GetIssuer() (iss string, err error) { + return m.toString(ClaimIssuer) +} + +// VerifyIssuer compares the iss claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyIssuer(cmp string, required bool) (ok bool) { + var ( + iss string + err error + ) + + if iss, err = m.GetIssuer(); err != nil { + return false + } + + if iss == "" { + return !required + } + + return validString(iss, cmp, required) +} + +// GetSubject returns the 'sub' claim. +func (m MapClaims) GetSubject() (sub string, err error) { + return m.toString(ClaimSubject) +} + +// VerifySubject compares the iss claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifySubject(cmp string, required bool) (ok bool) { + var ( + sub string + err error + ) + + if sub, err = m.GetSubject(); err != nil { + return false + } + + if sub == "" { + return !required + } + + return validString(sub, cmp, required) +} + +// GetAudience returns the 'aud' claim. +func (m MapClaims) GetAudience() (aud ClaimStrings, err error) { + return m.toClaimsString(ClaimAudience) +} + // VerifyAudience compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyAudience(cmp string, req bool) bool { +func (m MapClaims) VerifyAudience(cmp string, required bool) (ok bool) { var ( - aud []string - ok bool + aud ClaimStrings + err error ) - if aud, ok = StringSliceFromMap(m[consts.ClaimAudience]); ok { - return verifyAud(aud, cmp, req) + if aud, err = m.GetAudience(); err != nil { + return false + } + + if aud == nil { + return !required + } + + return verifyAud(aud, cmp, required) +} + +// VerifyAudienceAll compares the aud claim against cmp. +// If required is false, this method will return true if the value matches or is unset. +// This variant requires all of the audience values in the cmp. +func (m MapClaims) VerifyAudienceAll(cmp []string, required bool) (ok bool) { + var ( + aud ClaimStrings + err error + ) + + if aud, err = m.GetAudience(); err != nil { + return false + } + + if aud == nil { + return !required + } + + return verifyAudAll(aud, cmp, required) +} + +// VerifyAudienceAny compares the aud claim against cmp. +// If required is false, this method will return true if the value matches or is unset. +// This variant requires any of the audience values in the cmp. +func (m MapClaims) VerifyAudienceAny(cmp []string, required bool) (ok bool) { + var ( + aud ClaimStrings + err error + ) + + if aud, err = m.GetAudience(); err != nil { + return false + } + + if aud == nil { + return !required } - return false + return verifyAudAny(aud, cmp, required) } -// VerifyExpiresAt compares the exp claim against cmp. +// GetExpirationTime returns the 'exp' claim. +func (m MapClaims) GetExpirationTime() (exp *NumericDate, err error) { + return m.toNumericDate(ClaimExpirationTime) +} + +// VerifyExpirationTime compares the exp claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { - if v, ok := m.toInt64(consts.ClaimExpirationTime); ok { - return verifyExp(v, cmp, req) +func (m MapClaims) VerifyExpirationTime(cmp int64, required bool) (ok bool) { + var ( + exp *NumericDate + err error + ) + + if exp, err = m.GetExpirationTime(); err != nil { + return false + } + + if exp == nil { + return !required } - return !req + + return validInt64Future(exp.Int64(), cmp, required) +} + +// GetIssuedAt returns the 'iat' claim. +func (m MapClaims) GetIssuedAt() (iat *NumericDate, err error) { + return m.toNumericDate(ClaimIssuedAt) } // VerifyIssuedAt compares the iat claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { - if v, ok := m.toInt64(consts.ClaimIssuedAt); ok { - return verifyIat(v, cmp, req) +func (m MapClaims) VerifyIssuedAt(cmp int64, required bool) (ok bool) { + var ( + iat *NumericDate + err error + ) + + if iat, err = m.GetIssuedAt(); err != nil { + return false } - return !req + + if iat == nil { + return !required + } + + return validInt64Past(iat.Int64(), cmp, required) } -// VerifyIssuer compares the iss claim against cmp. -// If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { - iss, _ := m[consts.ClaimIssuer].(string) - return verifyIss(iss, cmp, req) +// GetNotBefore returns the 'nbf' claim. +func (m MapClaims) GetNotBefore() (nbf *NumericDate, err error) { + return m.toNumericDate(ClaimNotBefore) } // VerifyNotBefore compares the nbf claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { - if v, ok := m.toInt64(consts.ClaimNotBefore); ok { - return verifyNbf(v, cmp, req) +func (m MapClaims) VerifyNotBefore(cmp int64, required bool) (ok bool) { + var ( + nbf *NumericDate + err error + ) + + if nbf, err = m.GetNotBefore(); err != nil { + return false } - return !req + if nbf == nil { + return !required + } + + return validInt64Past(nbf.Int64(), cmp, required) } -func (m MapClaims) toInt64(claim string) (int64, bool) { - switch t := m[claim].(type) { - case float64: - return int64(t), true - case int64: - return t, true - case json.Number: - v, err := t.Int64() - if err == nil { - return v, true - } +func (m MapClaims) ToMapClaims() MapClaims { + if m == nil { + return nil + } - vf, err := t.Float64() - if err != nil { - return 0, false - } + return m +} - return int64(vf), true +func (m MapClaims) ToMap() map[string]any { + return m +} + +// Valid validates the given claims. By default it only validates time based claims "exp, iat, nbf"; there is no +// accounting for clock skew, and if any of the above claims are not in the token, the claims will still be considered +// valid. However all of these options can be tuned by the opts. +func (m MapClaims) Valid(opts ...ClaimValidationOption) (err error) { + vopts := &ClaimValidationOptions{} + + for _, opt := range opts { + opt(vopts) } - return 0, false -} + var now int64 + + if vopts.timef != nil { + now = vopts.timef().UTC().Unix() + } else { + now = TimeFunc().UTC().Unix() + } -// Valid validates time based claims "exp, iat, nbf". -// There is no accounting for clock skew. -// As well, if any of the above claims are not in the token, it will still -// be considered a valid claim. -func (m MapClaims) Valid() error { vErr := new(ValidationError) - now := TimeFunc().Unix() - if !m.VerifyExpiresAt(now, false) { + if !m.VerifyExpirationTime(now, vopts.expRequired) { vErr.Inner = errors.New("Token is expired") vErr.Errors |= ValidationErrorExpired } - if !m.VerifyIssuedAt(now, false) { + if !m.VerifyIssuedAt(now, vopts.iatRequired) { vErr.Inner = errors.New("Token used before issued") vErr.Errors |= ValidationErrorIssuedAt } - if !m.VerifyNotBefore(now, false) { + if !m.VerifyNotBefore(now, vopts.nbfRequired) { vErr.Inner = errors.New("Token is not valid yet") vErr.Errors |= ValidationErrorNotValidYet } + if len(vopts.iss) != 0 { + if !m.VerifyIssuer(vopts.iss, true) { + vErr.Inner = errors.New("Token has invalid issuer") + vErr.Errors |= ValidationErrorIssuer + } + } + + if len(vopts.sub) != 0 { + if !m.VerifySubject(vopts.sub, true) { + vErr.Inner = errors.New("Token has invalid subject") + vErr.Errors |= ValidationErrorSubject + } + } + + if len(vopts.aud) != 0 { + if !m.VerifyAudienceAny(vopts.aud, true) { + vErr.Inner = errors.New("Token has invalid audience") + vErr.Errors |= ValidationErrorAudience + } + } + + if len(vopts.audAll) != 0 { + if !m.VerifyAudienceAll(vopts.audAll, true) { + vErr.Inner = errors.New("Token has invalid audience") + vErr.Errors |= ValidationErrorAudience + } + } + if vErr.valid() { return nil } @@ -128,68 +287,80 @@ func (m MapClaims) Valid() error { return vErr } -func (m MapClaims) UnmarshalJSON(b []byte) error { - // This custom unmarshal allows to configure the - // go-jose decoding settings since there is no other way - // see https://github.com/square/go-jose/issues/353. - // If issue is closed with a better solution - // this custom Unmarshal method can be removed - d := jjson.NewDecoder(bytes.NewReader(b)) +func (m MapClaims) UnmarshalJSON(data []byte) error { + decoder := jjson.NewDecoder(bytes.NewReader(data)) + decoder.SetNumberType(jjson.UnmarshalIntOrFloat) + mp := map[string]any(m) - d.SetNumberType(jjson.UnmarshalIntOrFloat) - if err := d.Decode(&mp); err != nil { + + if err := decoder.Decode(&mp); err != nil { return errorsx.WithStack(err) } return nil } -func verifyAud(aud []string, cmp string, required bool) bool { - if len(aud) == 0 { - return !required - } +func (m MapClaims) toInt64(claim string) (val int64, ok bool) { + var v any - for _, a := range aud { - if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { - return true - } + if v, ok = m[claim]; !ok { + return 0, false } - return false + return toInt64(v) } -func verifyExp(exp int64, now int64, required bool) bool { - if exp == 0 { - return !required - } - - return now <= exp -} +func (m MapClaims) toNumericDate(key string) (date *NumericDate, err error) { + var ( + v any + ok bool + ) -func verifyIat(iat int64, now int64, required bool) bool { - if iat == 0 { - return !required + if v, ok = m[key]; !ok { + return nil, nil } - return now >= iat + return toNumericDate(v) } -func verifyIss(iss string, cmp string, required bool) bool { - if iss == "" { - return !required +func (m MapClaims) toString(key string) (value string, err error) { + var ( + ok bool + raw any + ) + + if raw, ok = m[key]; !ok { + return "", nil } - if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 { - return true - } else { - return false + if value, ok = raw.(string); !ok { + return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) } + + return value, nil } -func verifyNbf(nbf int64, now int64, required bool) bool { - if nbf == 0 { - return !required +func (m MapClaims) toClaimsString(key string) (ClaimStrings, error) { + var cs []string + + switch v := m[key].(type) { + case string: + cs = append(cs, v) + case []string: + cs = v + case []any: + for _, a := range v { + if vs, ok := a.(string); !ok { + return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) + } else { + cs = append(cs, vs) + } + } + case nil: + return nil, nil + default: + return cs, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) } - return now >= nbf + return cs, nil } diff --git a/token/jwt/claims_map_test.go b/token/jwt/claims_map_test.go index 9613e461..ba16becf 100644 --- a/token/jwt/claims_map_test.go +++ b/token/jwt/claims_map_test.go @@ -4,102 +4,965 @@ package jwt import ( + "errors" "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "authelia.com/provider/oauth2/internal/consts" ) -// Test taken from taken from [here](https://raw.githubusercontent.com/form3tech-oss/jwt-go/master/map_claims_test.go). -func Test_mapClaims_list_aud(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: []string{"foo"}, +func TestMapClaims_VerifyAudience(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp string + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + ClaimAudience: []string{"foo"}, + }, + "foo", + true, + true, + }, + { + "ShouldPassMultiple", + MapClaims{ + ClaimAudience: []string{"foo", "bar"}, + }, + "foo", + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + "foo", + true, + false, + }, + { + "ShouldFailNoMatch", + MapClaims{ + ClaimAudience: []string{"bar"}, + }, + "foo", + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + "foo", + false, + true, + }, + { + "ShouldPassTypeAny", + MapClaims{ + ClaimAudience: []any{"foo"}, + }, + "foo", + true, + true, + }, + { + "ShouldPassTypeString", + MapClaims{ + ClaimAudience: "foo", + }, + "foo", + true, + true, + }, + { + "ShouldFailTypeString", + MapClaims{ + consts.ClaimAudience: "bar", + }, + "foo", + true, + false, + }, + { + "ShouldFailTypeNil", + MapClaims{ + consts.ClaimAudience: nil, + }, + "foo", + true, + false, + }, + { + "ShouldFailTypeSliceAnyInt", + MapClaims{ + consts.ClaimAudience: []any{1, 2, 3}, + }, + "foo", + true, + false, + }, + { + "ShouldFailTypeInt", + MapClaims{ + consts.ClaimAudience: 1, + }, + "foo", + true, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyAudience(tc.cmp, tc.required)) + }) + } +} + +func TestMapClaims_VerifyAudienceAll(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp []string + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimAudience: []string{"foo"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldFailMultipleAll", + MapClaims{ + consts.ClaimAudience: []string{"foo"}, + }, + []string{"foo", "bar"}, + true, + false, + }, + { + "ShouldPassMultiple", + MapClaims{ + consts.ClaimAudience: []string{"foo", "bar"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldPassMultipleAll", + MapClaims{ + consts.ClaimAudience: []string{"foo", "bar"}, + }, + []string{"foo", "bar"}, + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailNoMatch", + MapClaims{ + consts.ClaimAudience: []string{"bar"}, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + []string{"foo"}, + false, + true, + }, + { + "ShouldPassTypeAny", + MapClaims{ + consts.ClaimAudience: []any{"foo"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldPassTypeString", + MapClaims{ + consts.ClaimAudience: "foo", + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldFailTypeString", + MapClaims{ + consts.ClaimAudience: "bar", + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeNil", + MapClaims{ + consts.ClaimAudience: nil, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeSliceAnyInt", + MapClaims{ + consts.ClaimAudience: []any{1, 2, 3}, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeInt", + MapClaims{ + consts.ClaimAudience: 1, + }, + []string{"foo"}, + true, + false, + }, } - want := true - got := mapClaims.VerifyAudience("foo", true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyAudienceAll(tc.cmp, tc.required)) + }) } } -// This is a custom test to check that an empty -// list with require == false returns valid -func Test_mapClaims_empty_list_aud(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: []string{}, +func TestMapClaims_VerifyAudienceAny(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp []string + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimAudience: []string{"foo"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldPassMultipleAny", + MapClaims{ + consts.ClaimAudience: []string{"foo", "baz"}, + }, + []string{"bar", "baz"}, + true, + true, + }, + { + "ShouldPassMultiple", + MapClaims{ + consts.ClaimAudience: []string{"foo", "bar"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldPassMultipleAll", + MapClaims{ + consts.ClaimAudience: []string{"foo", "bar"}, + }, + []string{"foo", "bar"}, + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailNoMatch", + MapClaims{ + consts.ClaimAudience: []string{"bar"}, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + []string{"foo"}, + false, + true, + }, + { + "ShouldPassTypeAny", + MapClaims{ + consts.ClaimAudience: []any{"foo"}, + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldPassTypeString", + MapClaims{ + consts.ClaimAudience: "foo", + }, + []string{"foo"}, + true, + true, + }, + { + "ShouldFailTypeString", + MapClaims{ + consts.ClaimAudience: "bar", + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeNil", + MapClaims{ + consts.ClaimAudience: nil, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeSliceAnyInt", + MapClaims{ + consts.ClaimAudience: []any{1, 2, 3}, + }, + []string{"foo"}, + true, + false, + }, + { + "ShouldFailTypeInt", + MapClaims{ + consts.ClaimAudience: 1, + }, + []string{"foo"}, + true, + false, + }, } - want := true - got := mapClaims.VerifyAudience("foo", false) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyAudienceAny(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_list_interface_aud(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: []any{"foo"}, +func TestMapClaims_VerifyIssuer(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp string + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimIssuer: "foo", + }, + "foo", + true, + true, + }, + { + "ShouldFailEmptyString", + MapClaims{ + consts.ClaimIssuer: "", + }, + "foo", + true, + false, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + "foo", + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + "foo", + false, + true, + }, + { + "ShouldFailNoMatch", + MapClaims{ + consts.ClaimIssuer: "bar", + }, + "foo", + true, + false, + }, + { + "ShouldFailWrongType", + MapClaims{ + consts.ClaimIssuer: 5, + }, + "5", + true, + false, + }, + { + "ShouldFailNil", + MapClaims{ + consts.ClaimIssuer: nil, + }, + "foo", + true, + false, + }, } - want := true - got := mapClaims.VerifyAudience("foo", true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyIssuer(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_string_aud(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: "foo", +func TestMapClaims_VerifySubject(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp string + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimSubject: "foo", + }, + "foo", + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + "foo", + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + "foo", + false, + true, + }, + { + "ShouldFailNoMatch", + MapClaims{ + consts.ClaimSubject: "bar", + }, + "foo", + true, + false, + }, + { + "ShouldFailWrongType", + MapClaims{ + consts.ClaimSubject: 5, + }, + "5", + true, + false, + }, + { + "ShouldFailNil", + MapClaims{ + consts.ClaimSubject: nil, + }, + "foo", + true, + false, + }, } - want := true - got := mapClaims.VerifyAudience("foo", true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifySubject(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_list_aud_no_match(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: []string{"bar"}, +func TestMapClaims_VerifyExpiresAt(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp int64 + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimExpirationTime: int64(123), + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt", + MapClaims{ + consts.ClaimExpirationTime: 123, + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt32", + MapClaims{ + consts.ClaimExpirationTime: int32(123), + }, + int64(123), + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + int64(123), + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + int64(123), + false, + true, + }, + { + "ShouldFailNoMatch", + MapClaims{ + consts.ClaimExpirationTime: 4, + }, + int64(123), + true, + false, + }, + { + "ShouldFailWrongType", + MapClaims{ + consts.ClaimExpirationTime: true, + }, + int64(123), + true, + false, + }, + { + "ShouldFailNil", + MapClaims{ + consts.ClaimExpirationTime: nil, + }, + int64(123), + true, + false, + }, } - want := false - got := mapClaims.VerifyAudience("foo", true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyExpirationTime(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_string_aud_fail(t *testing.T) { - mapClaims := MapClaims{ - consts.ClaimAudience: "bar", +func TestMapClaims_VerifyIssuedAt(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp int64 + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimIssuedAt: int64(123), + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt", + MapClaims{ + consts.ClaimIssuedAt: 123, + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt32", + MapClaims{ + consts.ClaimIssuedAt: int32(123), + }, + int64(123), + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + int64(123), + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + int64(123), + false, + true, + }, + { + "ShouldFailFuture", + MapClaims{ + consts.ClaimIssuedAt: 9000, + }, + int64(123), + true, + false, + }, + { + "ShouldPassPast", + MapClaims{ + consts.ClaimIssuedAt: 4, + }, + int64(123), + true, + true, + }, + { + "ShouldFailWrongType", + MapClaims{ + consts.ClaimIssuedAt: true, + }, + int64(123), + true, + false, + }, + { + "ShouldFailNil", + MapClaims{ + consts.ClaimIssuedAt: nil, + }, + int64(123), + true, + false, + }, } - want := false - got := mapClaims.VerifyAudience("foo", true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyIssuedAt(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_string_aud_no_claim(t *testing.T) { - mapClaims := MapClaims{} - want := false - got := mapClaims.VerifyAudience("foo", true) +func TestMapClaims_VerifyNotBefore(t *testing.T) { + testCases := []struct { + name string + have MapClaims + cmp int64 + required bool + expected bool + }{ + { + "ShouldPass", + MapClaims{ + consts.ClaimNotBefore: int64(123), + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt", + MapClaims{ + consts.ClaimNotBefore: 123, + }, + int64(123), + true, + true, + }, + { + "ShouldPassStandardInt32", + MapClaims{ + consts.ClaimNotBefore: int32(123), + }, + int64(123), + true, + true, + }, + { + "ShouldFailNoClaim", + MapClaims{}, + int64(123), + true, + false, + }, + { + "ShouldPassNoClaim", + MapClaims{}, + int64(123), + false, + true, + }, + { + "ShouldFailFuture", + MapClaims{ + consts.ClaimNotBefore: 9000, + }, + int64(123), + true, + false, + }, + { + "ShouldPassPast", + MapClaims{ + consts.ClaimNotBefore: 4, + }, + int64(123), + true, + true, + }, + { + "ShouldFailWrongType", + MapClaims{ + consts.ClaimNotBefore: true, + }, + int64(123), + true, + false, + }, + { + "ShouldFailNil", + MapClaims{ + consts.ClaimNotBefore: nil, + }, + int64(123), + true, + false, + }, + } - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.have.VerifyNotBefore(tc.cmp, tc.required)) + }) } } -func Test_mapClaims_string_aud_no_claim_not_required(t *testing.T) { - mapClaims := MapClaims{} - want := true - got := mapClaims.VerifyAudience("foo", false) +func TestMapClaims_Valid(t *testing.T) { + testCases := []struct { + name string + have MapClaims + opts []ClaimValidationOption + errs []uint32 + err string + }{ + { + "ShouldPass", + MapClaims{}, + nil, + nil, + "", + }, + { + "ShouldFailEXPNotPresent", + MapClaims{}, + []ClaimValidationOption{ValidateRequireExpiresAt(), ValidateTimeFunc(func() time.Time { return time.Unix(0, 0) })}, + []uint32{ValidationErrorExpired}, + "Token is expired", + }, + { + "ShouldFailIATNotPresent", + MapClaims{}, + []ClaimValidationOption{ValidateRequireIssuedAt()}, + []uint32{ValidationErrorIssuedAt}, + "Token used before issued", + }, + { + "ShouldFailNBFNotPresent", + MapClaims{}, + []ClaimValidationOption{ValidateRequireNotBefore()}, + []uint32{ValidationErrorNotValidYet}, + "Token is not valid yet", + }, + { + "ShouldFailExpPast", + MapClaims{ + consts.ClaimExpirationTime: 1, + }, + nil, + []uint32{ValidationErrorExpired}, + "Token is expired", + }, + { + "ShouldFailIssuedFuture", + MapClaims{ + consts.ClaimIssuedAt: 999999999999999, + }, + nil, + []uint32{ValidationErrorIssuedAt}, + "Token used before issued", + }, + { + "ShouldFailMultiple", + MapClaims{ + consts.ClaimExpirationTime: 1, + consts.ClaimIssuedAt: 999999999999999, + }, + nil, + []uint32{ValidationErrorIssuedAt, ValidationErrorExpired}, + "Token used before issued", + }, + { + "ShouldPassIssuer", + MapClaims{ + consts.ClaimIssuer: "abc", + }, + []ClaimValidationOption{ValidateIssuer("abc")}, + nil, + "", + }, + { + "ShouldFailIssuer", + MapClaims{ + consts.ClaimIssuer: "abc", + }, + []ClaimValidationOption{ValidateIssuer("abc2"), ValidateTimeFunc(time.Now)}, + []uint32{ValidationErrorIssuer}, + "Token has invalid issuer", + }, + { + "ShouldFailIssuerAbsent", + MapClaims{}, + []ClaimValidationOption{ValidateIssuer("abc2")}, + []uint32{ValidationErrorIssuer}, + "Token has invalid issuer", + }, + { + "ShouldPassSubject", + MapClaims{ + consts.ClaimSubject: "abc", + }, + []ClaimValidationOption{ValidateSubject("abc")}, + nil, + "", + }, + { + "ShouldFailSubject", + MapClaims{ + consts.ClaimSubject: "abc", + }, + []ClaimValidationOption{ValidateSubject("abc2")}, + []uint32{ValidationErrorSubject}, + "Token has invalid subject", + }, + { + "ShouldFailSubjectAbsent", + MapClaims{}, + []ClaimValidationOption{ValidateSubject("abc2")}, + []uint32{ValidationErrorSubject}, + "Token has invalid subject", + }, + { + "ShouldPassAudienceAll", + MapClaims{ + consts.ClaimAudience: []any{"abc", "123"}, + }, + []ClaimValidationOption{ValidateAudienceAll("abc", "123")}, + nil, + "", + }, + { + "ShouldFailAudienceAll", + MapClaims{ + consts.ClaimAudience: []any{"abc", "123"}, + }, + []ClaimValidationOption{ValidateAudienceAll("abc", "123", "456")}, + []uint32{ValidationErrorAudience}, + "Token has invalid audience", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := tc.have.Valid(tc.opts...) + + if len(tc.err) == 0 && tc.err == "" { + assert.NoError(t, actual) + } else { + if tc.err != "" { + assert.EqualError(t, actual, tc.err) + } + + var e *ValidationError + + errors.As(actual, &e) + + require.NotNil(t, e) + + var errs uint32 + + for _, err := range tc.errs { + errs |= err + + assert.True(t, e.Has(err)) + } - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + assert.Equal(t, errs, e.Errors) + } + }) } } diff --git a/token/jwt/claims_test.go b/token/jwt/claims_test.go index 667de76c..c5245036 100644 --- a/token/jwt/claims_test.go +++ b/token/jwt/claims_test.go @@ -21,8 +21,49 @@ func TestToTime(t *testing.T) { assert.Equal(t, time.Time{}, ToTime(nil)) assert.Equal(t, time.Time{}, ToTime("1234")) - now := time.Now().UTC().Round(time.Second) + now := time.Now().UTC().Truncate(TimePrecision) assert.Equal(t, now, ToTime(now)) assert.Equal(t, now, ToTime(now.Unix())) assert.Equal(t, now, ToTime(float64(now.Unix()))) } + +func TestFilter(t *testing.T) { + testCases := []struct { + name string + have map[string]any + filter []string + expected map[string]any + }{ + { + "ShouldFilterNone", + map[string]any{"abc": 123}, + []string{}, + map[string]any{"abc": 123}, + }, + { + "ShouldFilterNoneNil", + map[string]any{"abc": 123}, + []string{}, + map[string]any{"abc": 123}, + }, + { + "ShouldFilterAll", + map[string]any{"abc": 123, "example": 123}, + []string{"abc", "example"}, + map[string]any{}, + }, + { + "ShouldFilterSome", + map[string]any{"abc": 123, "example": 123}, + []string{"abc"}, + map[string]any{"example": 123}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + have := Filter(tc.have, tc.filter...) + assert.Equal(t, tc.expected, have) + }) + } +} diff --git a/token/jwt/client.go b/token/jwt/client.go new file mode 100644 index 00000000..0082f94e --- /dev/null +++ b/token/jwt/client.go @@ -0,0 +1,494 @@ +package jwt + +import ( + "github.com/go-jose/go-jose/v4" +) + +// NewJARClient converts a type into a Client provided it implements the JARClient. +func NewJARClient(client any) Client { + switch c := client.(type) { + case JARClient: + return &decoratedJARClient{JARClient: c} + default: + return nil + } +} + +// NewIDTokenClient converts a type into a Client provided it implements the IDTokenClient. +func NewIDTokenClient(client any) Client { + switch c := client.(type) { + case IDTokenClient: + return &decoratedIDTokenClient{IDTokenClient: c} + default: + return nil + } +} + +// NewJARMClient converts a type into a Client provided it implements the JARMClient. +func NewJARMClient(client any) Client { + switch c := client.(type) { + case JARMClient: + return &decoratedJARMClient{JARMClient: c} + default: + return nil + } +} + +// NewUserInfoClient converts a type into a Client provided it implements the UserInfoClient. +func NewUserInfoClient(client any) Client { + switch c := client.(type) { + case UserInfoClient: + return &decoratedUserInfoClient{UserInfoClient: c} + default: + return nil + } +} + +// NewJWTProfileAccessTokenClient converts a type into a Client provided it implements the JWTProfileAccessTokenClient. +func NewJWTProfileAccessTokenClient(client any) Client { + switch c := client.(type) { + case JWTProfileAccessTokenClient: + return &decoratedJWTProfileAccessTokenClient{JWTProfileAccessTokenClient: c} + default: + return nil + } +} + +// NewIntrospectionClient converts a type into a Client provided it implements the IntrospectionClient. +func NewIntrospectionClient(client any) Client { + switch c := client.(type) { + case IntrospectionClient: + return &decoratedIntrospectionClient{IntrospectionClient: c} + default: + return nil + } +} + +// NewStatelessJWTProfileIntrospectionClient converts a type into a Client provided it implements either the +// IntrospectionClient or JWTProfileAccessTokenClient. +func NewStatelessJWTProfileIntrospectionClient(client any) Client { + switch c := client.(type) { + case IntrospectionClient: + return &decoratedIntrospectionClient{IntrospectionClient: c} + case JWTProfileAccessTokenClient: + return &decoratedJWTProfileAccessTokenClient{JWTProfileAccessTokenClient: c} + default: + return nil + } +} + +// Client represents a client which can be used to sign, verify, encrypt, and decrypt JWT's. +type Client interface { + GetSigningKeyID() (kid string) + GetSigningAlg() (alg string) + GetEncryptionKeyID() (kid string) + GetEncryptionAlg() (alg string) + GetEncryptionEnc() (enc string) + + IsClientSigned() (is bool) + + BaseClient +} + +// BaseClient represents the base implementation for any JWT compatible client. +type BaseClient interface { + // GetID returns the client ID. + GetID() string + + // GetClientSecretPlainText returns the ClientSecret as plaintext if available. The semantics of this function + // return values are important. + // If the client is not configured with a secret the return should be: + // - secret with value nil, ok with value false, and err with value of nil + // If the client is configured with a secret but is hashed or otherwise not a plaintext value: + // - secret with value nil, ok with value true, and err with value of nil + // If an error occurs retrieving the secret other than this: + // - secret with value nil, ok with value true, and err with value of the error + // If the plaintext secret is successful: + // - secret with value of the bytes of the plaintext secret, ok with value true, and err with value of nil + GetClientSecretPlainText() (secret []byte, ok bool, err error) + + // GetJSONWebKeys returns the JSON Web Key Set containing the public key used by the client to authenticate. + GetJSONWebKeys() (jwks *jose.JSONWebKeySet) + + // GetJSONWebKeysURI returns the URL for lookup of JSON Web Key Set containing the + // public key used by the client to authenticate. + GetJSONWebKeysURI() (uri string) +} + +// JARClient represents the implementation for any JWT Authorization Request compatible client. +type JARClient interface { + // GetRequestObjectSigningKeyID returns the specific key identifier used to satisfy JWS requirements of the request + // object specifications. If unspecified the other available parameters will be utilized to select an appropriate + // key. + GetRequestObjectSigningKeyID() (kid string) + + // GetRequestObjectSigningAlg is equivalent to the 'request_object_signing_alg' client metadata + // value which determines the JWS alg algorithm [JWA] that MUST be used for signing Request Objects sent to the OP. + // All Request Objects from this Client MUST be rejected, if not signed with this algorithm. Request Objects are + // described in Section 6.1 of OpenID Connect Core 1.0 [OpenID.Core]. This algorithm MUST be used both when the + // Request Object is passed by value (using the request parameter) and when it is passed by reference (using the + // request_uri parameter). Servers SHOULD support RS256. The value none MAY be used. The default, if omitted, is + // that any algorithm supported by the OP and the RP MAY be used. + GetRequestObjectSigningAlg() (alg string) + + // GetRequestObjectEncryptionKeyID returns the specific key identifier used to satisfy JWE requirements of the + // request object specifications. If unspecified the other available parameters will be utilized to select an + // appropriate key. + GetRequestObjectEncryptionKeyID() (kid string) + + // GetRequestObjectEncryptionAlg is equivalent to the 'request_object_encryption_alg' client metadata value which + // determines the JWE alg algorithm [JWA] the RP is declaring that it may use for encrypting Request Objects sent to + // the OP. This parameter SHOULD be included when symmetric encryption will be used, since this signals to the OP + // that a client_secret value needs to be returned from which the symmetric key will be derived, that might not + // otherwise be returned. The RP MAY still use other supported encryption algorithms or send unencrypted Request + // Objects, even when this parameter is present. If both signing and encryption are requested, the Request Object + // will be signed then encrypted, with the result being a Nested JWT, as defined in [JWT]. The default, if omitted, + // is that the RP is not declaring whether it might encrypt any Request Objects. + GetRequestObjectEncryptionAlg() (alg string) + + // GetRequestObjectEncryptionEnc is equivalent to the 'request_object_encryption_enc' client metadata value which + // determines the JWE enc algorithm [JWA] the RP is declaring that it may use for encrypting Request Objects sent to + // the OP. If request_object_encryption_alg is specified, the default request_object_encryption_enc value is + // A128CBC-HS256. When request_object_encryption_enc is included, request_object_encryption_alg MUST also be + // provided. + GetRequestObjectEncryptionEnc() (enc string) + + BaseClient +} + +type decoratedJARClient struct { + JARClient +} + +func (r *decoratedJARClient) GetSigningKeyID() (kid string) { + return r.GetRequestObjectSigningKeyID() +} + +func (r *decoratedJARClient) GetSigningAlg() (alg string) { + return r.GetRequestObjectSigningAlg() +} + +func (r *decoratedJARClient) GetEncryptionKeyID() (kid string) { + return r.GetRequestObjectEncryptionKeyID() +} + +func (r *decoratedJARClient) GetEncryptionAlg() (alg string) { + return r.GetRequestObjectEncryptionAlg() +} + +func (r *decoratedJARClient) GetEncryptionEnc() (enc string) { + return r.GetRequestObjectEncryptionEnc() +} + +func (r *decoratedJARClient) IsClientSigned() (is bool) { + return true +} + +type IDTokenClient interface { + // GetIDTokenSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements of the ID + // Token specifications. If unspecified the other available parameters will be utilized to select an appropriate + // key. + GetIDTokenSignedResponseKeyID() (kid string) + + // GetIDTokenSignedResponseAlg is equivalent to the 'id_token_signed_response_alg' client metadata value which + // determines the JWS alg algorithm [JWA] REQUIRED for signing the ID Token issued to this Client. The value none + // MUST NOT be used as the ID Token alg value unless the Client uses only Response Types that return no ID Token + // from the Authorization Endpoint (such as when only using the Authorization Code Flow). The default, if omitted, + // is RS256. The public key for validating the signature is provided by retrieving the JWK Set referenced by the + // jwks_uri element from OpenID Connect Discovery 1.0 [OpenID.Discovery]. + GetIDTokenSignedResponseAlg() (alg string) + + // GetIDTokenEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements of the ID + // Token specifications. If unspecified the other available parameters will be utilized to select an appropriate + // key. + GetIDTokenEncryptedResponseKeyID() (kid string) + + // GetIDTokenEncryptedResponseAlg is equivalent to the 'id_token_encrypted_response_alg' client metadata value which + // determines the JWE alg algorithm [JWA] REQUIRED for encrypting the ID Token issued to this Client. If this is + // requested, the response will be signed then encrypted, with the result being a Nested JWT, as defined in [JWT]. + // The default, if omitted, is that no encryption is performed. + GetIDTokenEncryptedResponseAlg() (alg string) + + // GetIDTokenEncryptedResponseEnc is equivalent to the 'id_token_encrypted_response_enc' client metadata value which + // determines the JWE enc algorithm [JWA] REQUIRED for encrypting the ID Token issued to this Client. If + // id_token_encrypted_response_alg is specified, the default id_token_encrypted_response_enc value is A128CBC-HS256. + // When id_token_encrypted_response_enc is included, id_token_encrypted_response_alg MUST also be provided. + GetIDTokenEncryptedResponseEnc() (enc string) + + BaseClient +} + +type decoratedIDTokenClient struct { + IDTokenClient +} + +func (r *decoratedIDTokenClient) GetSigningKeyID() (kid string) { + return r.GetIDTokenSignedResponseKeyID() +} + +func (r *decoratedIDTokenClient) GetSigningAlg() (alg string) { + return r.GetIDTokenSignedResponseAlg() +} + +func (r *decoratedIDTokenClient) GetEncryptionKeyID() (kid string) { + return r.GetIDTokenEncryptedResponseKeyID() +} + +func (r *decoratedIDTokenClient) GetEncryptionAlg() (alg string) { + return r.GetIDTokenEncryptedResponseAlg() +} + +func (r *decoratedIDTokenClient) GetEncryptionEnc() (enc string) { + return r.GetIDTokenEncryptedResponseEnc() +} + +func (r *decoratedIDTokenClient) IsClientSigned() (is bool) { + return false +} + +type JARMClient interface { + // GetAuthorizationSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements of the + // JWT-secured Authorization Response Method (JARM) specifications. If unspecified the other available parameters + // will be utilized to select an appropriate key. + GetAuthorizationSignedResponseKeyID() (kid string) + + // GetAuthorizationSignedResponseAlg is equivalent to the 'authorization_signed_response_alg' client metadata + // value which determines the JWS [RFC7515] alg algorithm JWA [RFC7518] REQUIRED for signing authorization + // responses. If this is specified, the response will be signed using JWS and the configured algorithm. The + // algorithm none is not allowed. The default, if omitted, is RS256. + GetAuthorizationSignedResponseAlg() (alg string) + + // GetAuthorizationEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements of + // the JWT-secured Authorization Response Method (JARM) specifications. If unspecified the other available parameters will be + // utilized to select an appropriate key. + GetAuthorizationEncryptedResponseKeyID() (kid string) + + // GetAuthorizationEncryptedResponseAlg is equivalent to the 'authorization_encrypted_response_alg' client metadata + // value which determines the JWE [RFC7516] alg algorithm JWA [RFC7518] REQUIRED for encrypting authorization + // responses. If both signing and encryption are requested, the response will be signed then encrypted, with the + // result being a Nested JWT, as defined in JWT [RFC7519]. The default, if omitted, is that no encryption is + // performed. + GetAuthorizationEncryptedResponseAlg() (alg string) + + // GetAuthorizationEncryptedResponseEnc is equivalent to the 'authorization_encrypted_response_enc' client + // metadata value which determines the JWE [RFC7516] enc algorithm JWA [RFC7518] REQUIRED for encrypting + // authorization responses. If authorization_encrypted_response_alg is specified, the default for this value is + // A128CBC-HS256. When authorization_encrypted_response_enc is included, authorization_encrypted_response_alg MUST + // also be provided. + GetAuthorizationEncryptedResponseEnc() (alg string) + + BaseClient +} + +type decoratedJARMClient struct { + JARMClient +} + +func (r *decoratedJARMClient) GetSigningKeyID() (kid string) { + return r.GetAuthorizationSignedResponseKeyID() +} + +func (r *decoratedJARMClient) GetSigningAlg() (alg string) { + return r.GetAuthorizationSignedResponseAlg() +} + +func (r *decoratedJARMClient) GetEncryptionKeyID() (kid string) { + return r.GetAuthorizationEncryptedResponseKeyID() +} + +func (r *decoratedJARMClient) GetEncryptionAlg() (alg string) { + return r.GetAuthorizationEncryptedResponseAlg() +} + +func (r *decoratedJARMClient) GetEncryptionEnc() (enc string) { + return r.GetAuthorizationEncryptedResponseEnc() +} + +func (r *decoratedJARMClient) IsClientSigned() (is bool) { + return false +} + +type UserInfoClient interface { + // GetUserinfoSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements of the User + // Info specifications. If unspecified the other available parameters will be utilized to select an appropriate + // key. + GetUserinfoSignedResponseKeyID() (kid string) + + // GetUserinfoSignedResponseAlg is equivalent to the 'userinfo_signed_response_alg' client metadata value which + // determines the JWS alg algorithm [JWA] REQUIRED for signing UserInfo Responses. If this is specified, the + // response will be JWT [JWT] serialized, and signed using JWS. The default, if omitted, is for the UserInfo + // Response to return the Claims as a UTF-8 [RFC3629] encoded JSON object using the application/json content-type. + GetUserinfoSignedResponseAlg() (alg string) + + // GetUserinfoEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements of the + // User Info specifications. If unspecified the other available parameters will be utilized to select an appropriate + // key. + GetUserinfoEncryptedResponseKeyID() (kid string) + + // GetUserinfoEncryptedResponseAlg is equivalent to the 'userinfo_encrypted_response_alg' client metadata value + // which determines the JWE alg algorithm [JWA] REQUIRED for encrypting the ID Token issued to this Client. If + // this is requested, the response will be signed then encrypted, with the result being a Nested JWT, as defined in + // [JWT]. The default, if omitted, is that no encryption is performed. + GetUserinfoEncryptedResponseAlg() (alg string) + + // GetUserinfoEncryptedResponseEnc is equivalent to the 'userinfo_encrypted_response_enc' client metadata value + // which determines the JWE enc algorithm [JWA] REQUIRED for encrypting UserInfo Responses. If + // userinfo_encrypted_response_alg is specified, the default userinfo_encrypted_response_enc value is A128CBC-HS256. + // When userinfo_encrypted_response_enc is included, userinfo_encrypted_response_alg MUST also be provided. + GetUserinfoEncryptedResponseEnc() (enc string) + + BaseClient +} + +type decoratedUserInfoClient struct { + UserInfoClient +} + +func (r *decoratedUserInfoClient) GetSigningKeyID() (kid string) { + return r.GetUserinfoSignedResponseKeyID() +} + +func (r *decoratedUserInfoClient) GetSigningAlg() (alg string) { + return r.GetUserinfoSignedResponseAlg() +} + +func (r *decoratedUserInfoClient) GetEncryptionKeyID() (kid string) { + return r.GetUserinfoEncryptedResponseKeyID() +} + +func (r *decoratedUserInfoClient) GetEncryptionAlg() (alg string) { + return r.GetUserinfoEncryptedResponseAlg() +} + +func (r *decoratedUserInfoClient) GetEncryptionEnc() (enc string) { + return r.GetUserinfoEncryptedResponseEnc() +} + +func (r *decoratedUserInfoClient) IsClientSigned() (is bool) { + return false +} + +type JWTProfileAccessTokenClient interface { + // GetAccessTokenSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements for + // JWT Profile for OAuth 2.0 Access Tokens specifications. If unspecified the other available parameters will be + // utilized to select an appropriate key. + GetAccessTokenSignedResponseKeyID() (kid string) + + // GetAccessTokenSignedResponseAlg determines the JWS [RFC7515] algorithm (alg value) as defined in JWA [RFC7518] + // for signing JWT Profile Access Token responses. If this is specified, the response will be signed using JWS and + // the configured algorithm. The default, if omitted, is none; i.e. unsigned responses unless the + // GetEnableJWTProfileOAuthAccessTokens receiver returns true in which case the default is RS256. + GetAccessTokenSignedResponseAlg() (alg string) + + // GetAccessTokenEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements for + // JWT Profile for OAuth 2.0 Access Tokens specifications. If unspecified the other available parameters will be + // utilized to select an appropriate key. + GetAccessTokenEncryptedResponseKeyID() (kid string) + + // GetAccessTokenEncryptedResponseAlg determines the JWE [RFC7516] algorithm (alg value) as defined in JWA [RFC7518] + // for content key encryption. If this is specified, the response will be encrypted using JWE and the configured + // content encryption algorithm (access_token_encrypted_response_enc). The default, if omitted, is that no + // encryption is performed. If both signing and encryption are requested, the response will be signed then + // encrypted, with the result being a Nested JWT, as defined in JWT [RFC7519]. + GetAccessTokenEncryptedResponseAlg() (alg string) + + // GetAccessTokenEncryptedResponseEnc determines the JWE [RFC7516] algorithm (enc value) as defined in JWA [RFC7518] + // for content encryption of access token responses. The default, if omitted, is A128CBC-HS256. Note: This parameter + // MUST NOT be specified without setting access_token_encrypted_response_alg. + GetAccessTokenEncryptedResponseEnc() (alg string) + + // GetEnableJWTProfileOAuthAccessTokens indicates this client should or should not issue JWT Profile Access Tokens. + GetEnableJWTProfileOAuthAccessTokens() (enforce bool) + + BaseClient +} + +type decoratedJWTProfileAccessTokenClient struct { + JWTProfileAccessTokenClient +} + +func (r *decoratedJWTProfileAccessTokenClient) GetSigningKeyID() (kid string) { + return r.GetAccessTokenSignedResponseKeyID() +} + +func (r *decoratedJWTProfileAccessTokenClient) GetSigningAlg() (alg string) { + return r.GetAccessTokenSignedResponseAlg() +} + +func (r *decoratedJWTProfileAccessTokenClient) GetEncryptionKeyID() (kid string) { + return r.GetAccessTokenEncryptedResponseKeyID() +} + +func (r *decoratedJWTProfileAccessTokenClient) GetEncryptionAlg() (alg string) { + return r.GetAccessTokenEncryptedResponseAlg() +} + +func (r *decoratedJWTProfileAccessTokenClient) GetEncryptionEnc() (enc string) { + return r.GetAccessTokenEncryptedResponseEnc() +} + +func (r *decoratedJWTProfileAccessTokenClient) IsClientSigned() (is bool) { + return false +} + +type IntrospectionClient interface { + // GetIntrospectionSignedResponseKeyID returns the specific key identifier used to satisfy JWS requirements for + // OAuth 2.0 JWT introspection response specifications. If unspecified the other available parameters will be + // utilized to select an appropriate key. + GetIntrospectionSignedResponseKeyID() (kid string) + + // GetIntrospectionSignedResponseAlg is equivalent to the 'introspection_signed_response_alg' client metadata + // value which determines the JWS [RFC7515] algorithm (alg value) as defined in JWA [RFC7518] for signing + // introspection responses. If this is specified, the response will be signed using JWS and the configured + // algorithm. The default, if omitted, is RS256. + GetIntrospectionSignedResponseAlg() (alg string) + + // GetIntrospectionEncryptedResponseKeyID returns the specific key identifier used to satisfy JWE requirements for + // OAuth 2.0 JWT introspection response specifications. If unspecified the other available parameters will be + // utilized to select an appropriate key. + GetIntrospectionEncryptedResponseKeyID() (kid string) + + // GetIntrospectionEncryptedResponseAlg is equivalent to the 'introspection_encrypted_response_alg' client metadata + // value which determines the JWE [RFC7516] algorithm (alg value) as defined in JWA [RFC7518] for content key + // encryption. If this is specified, the response will be encrypted using JWE and the configured content encryption + // algorithm (introspection_encrypted_response_enc). The default, if omitted, is that no encryption is performed. + // If both signing and encryption are requested, the response will be signed then encrypted, with the result being + // a Nested JWT, as defined in JWT [RFC7519]. + GetIntrospectionEncryptedResponseAlg() (alg string) + + // GetIntrospectionEncryptedResponseEnc is equivalent to the 'introspection_encrypted_response_enc' client metadata + // value which determines the JWE [RFC7516] algorithm (enc value) as defined in JWA [RFC7518] for content + // encryption of introspection responses. The default, if omitted, is A128CBC-HS256. Note: This parameter MUST NOT + // be specified without setting introspection_encrypted_response_alg. + GetIntrospectionEncryptedResponseEnc() (enc string) + + BaseClient +} + +type decoratedIntrospectionClient struct { + IntrospectionClient +} + +func (r *decoratedIntrospectionClient) GetSigningKeyID() (kid string) { + return r.GetIntrospectionSignedResponseKeyID() +} + +func (r *decoratedIntrospectionClient) GetSigningAlg() (alg string) { + return r.GetIntrospectionSignedResponseAlg() +} + +func (r *decoratedIntrospectionClient) GetEncryptionKeyID() (kid string) { + return r.GetIntrospectionEncryptedResponseKeyID() +} + +func (r *decoratedIntrospectionClient) GetEncryptionAlg() (alg string) { + return r.GetIntrospectionEncryptedResponseAlg() +} + +func (r *decoratedIntrospectionClient) GetEncryptionEnc() (enc string) { + return r.GetIntrospectionEncryptedResponseEnc() +} + +func (r *decoratedIntrospectionClient) IsClientSigned() (is bool) { + return false +} diff --git a/token/jwt/client_test.go b/token/jwt/client_test.go new file mode 100644 index 00000000..91e90280 --- /dev/null +++ b/token/jwt/client_test.go @@ -0,0 +1,71 @@ +package jwt + +import ( + "fmt" + + "github.com/go-jose/go-jose/v4" +) + +type testClient struct { + id string + secret []byte + secretNotPlainText bool + secretNotDefined bool + kid, alg string + encKID, encAlg, enc string + csigned bool + jwks *jose.JSONWebKeySet + jwksURI string +} + +func (r *testClient) GetID() string { + return r.id +} + +func (r *testClient) GetClientSecretPlainText() (secret []byte, ok bool, err error) { + if r.secretNotDefined { + return nil, false, nil + } + + if r.secretNotPlainText { + return nil, true, nil + } + + if r.secret != nil { + return r.secret, true, nil + } + + return nil, true, fmt.Errorf("not supported") +} + +func (r *testClient) GetSigningKeyID() (kid string) { + return r.kid +} + +func (r *testClient) GetSigningAlg() (alg string) { + return r.alg +} + +func (r *testClient) GetEncryptionKeyID() (kid string) { + return r.encKID +} + +func (r *testClient) GetEncryptionAlg() (alg string) { + return r.encAlg +} + +func (r *testClient) GetEncryptionEnc() (enc string) { + return r.enc +} + +func (r *testClient) IsClientSigned() (is bool) { + return r.csigned +} + +func (r *testClient) GetJSONWebKeys() (jwks *jose.JSONWebKeySet) { + return r.jwks +} + +func (r *testClient) GetJSONWebKeysURI() (uri string) { + return r.jwksURI +} diff --git a/token/jwt/consts.go b/token/jwt/consts.go new file mode 100644 index 00000000..c27c8ae8 --- /dev/null +++ b/token/jwt/consts.go @@ -0,0 +1,95 @@ +package jwt + +import ( + "github.com/go-jose/go-jose/v4" + + "authelia.com/provider/oauth2/internal/consts" +) + +const ( + SigningMethodNone = jose.SignatureAlgorithm(JSONWebTokenAlgNone) + + // UnsafeAllowNoneSignatureType is unsafe to use and should be use to correctly sign and verify alg:none JWT tokens. + UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed" +) + +type unsafeNoneMagicConstant string + +// Keyfunc is used by parsing methods to supply the key for verification. The function receives the parsed, but +// unverified Token. This allows you to use properties in the Header of the token (such as `kid`) to identify which key +// to use. +type Keyfunc func(token *Token) (key any, err error) + +var ( + // SignatureAlgorithmsNone contain all algorithms including 'none'. + SignatureAlgorithmsNone = []jose.SignatureAlgorithm{JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512} + + // SignatureAlgorithms contain all algorithms excluding 'none'. + SignatureAlgorithms = []jose.SignatureAlgorithm{jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512} + + // EncryptionKeyAlgorithms contains all valid JWE's for OAuth 2.0 and OpenID Connect 1.0. + EncryptionKeyAlgorithms = []jose.KeyAlgorithm{jose.RSA1_5, jose.RSA_OAEP, jose.RSA_OAEP_256, jose.A128KW, jose.A192KW, jose.A256KW, jose.DIRECT, jose.ECDH_ES, jose.ECDH_ES_A128KW, jose.ECDH_ES_A192KW, jose.ECDH_ES_A256KW, jose.A128GCMKW, jose.A192GCMKW, jose.A256GCMKW, jose.PBES2_HS256_A128KW, jose.PBES2_HS384_A192KW, jose.PBES2_HS512_A256KW} + + ContentEncryptionAlgorithms = []jose.ContentEncryption{jose.A128CBC_HS256, jose.A192CBC_HS384, jose.A256CBC_HS512, jose.A128GCM, jose.A192GCM, jose.A256GCM} +) + +const ( + ClaimJWTID = consts.ClaimJWTID + ClaimSessionID = consts.ClaimSessionID + ClaimIssuedAt = consts.ClaimIssuedAt + ClaimNotBefore = consts.ClaimNotBefore + ClaimRequestedAt = consts.ClaimRequestedAt + ClaimExpirationTime = consts.ClaimExpirationTime + ClaimAuthenticationTime = consts.ClaimAuthenticationTime + ClaimIssuer = consts.ClaimIssuer + ClaimSubject = consts.ClaimSubject + ClaimAudience = consts.ClaimAudience + ClaimGroups = consts.ClaimGroups + ClaimFullName = consts.ClaimFullName + ClaimPreferredUsername = consts.ClaimPreferredUsername + ClaimPreferredEmail = consts.ClaimPreferredEmail + ClaimEmailVerified = consts.ClaimEmailVerified + ClaimAuthorizedParty = consts.ClaimAuthorizedParty + ClaimAuthenticationContextClassReference = consts.ClaimAuthenticationContextClassReference + ClaimAuthenticationMethodsReference = consts.ClaimAuthenticationMethodsReference + ClaimClientIdentifier = consts.ClaimClientIdentifier + ClaimScope = consts.ClaimScope + ClaimScopeNonStandard = consts.ClaimScopeNonStandard + ClaimExtra = consts.ClaimExtra + ClaimActive = consts.ClaimActive + ClaimUsername = consts.ClaimUsername + ClaimTokenIntrospection = consts.ClaimTokenIntrospection + ClaimAccessTokenHash = consts.ClaimAccessTokenHash + ClaimCodeHash = consts.ClaimCodeHash + ClaimStateHash = consts.ClaimStateHash + ClaimNonce = consts.ClaimNonce + ClaimAuthorizedActor = consts.ClaimAuthorizedActor + ClaimActor = consts.ClaimActor +) + +const ( + JSONWebTokenHeaderKeyIdentifier = consts.JSONWebTokenHeaderKeyIdentifier + JSONWebTokenHeaderAlgorithm = consts.JSONWebTokenHeaderAlgorithm + JSONWebTokenHeaderEncryptionAlgorithm = consts.JSONWebTokenHeaderEncryptionAlgorithm + JSONWebTokenHeaderCompressionAlgorithm = consts.JSONWebTokenHeaderCompressionAlgorithm + JSONWebTokenHeaderPBES2Count = consts.JSONWebTokenHeaderPBES2Count + + JSONWebTokenHeaderType = consts.JSONWebTokenHeaderType + JSONWebTokenHeaderContentType = consts.JSONWebTokenHeaderContentType +) + +const ( + JSONWebTokenUseSignature = consts.JSONWebTokenUseSignature + JSONWebTokenUseEncryption = consts.JSONWebTokenUseEncryption +) + +const ( + JSONWebTokenTypeJWT = consts.JSONWebTokenTypeJWT + JSONWebTokenTypeAccessToken = consts.JSONWebTokenTypeAccessToken + JSONWebTokenTypeAccessTokenAlternative = consts.JSONWebTokenTypeAccessTokenAlternative + JSONWebTokenTypeTokenIntrospection = consts.JSONWebTokenTypeTokenIntrospection +) + +const ( + JSONWebTokenAlgNone = consts.JSONWebTokenAlgNone +) diff --git a/token/jwt/date.go b/token/jwt/date.go new file mode 100644 index 00000000..53eab0ac --- /dev/null +++ b/token/jwt/date.go @@ -0,0 +1,172 @@ +package jwt + +import ( + "crypto/subtle" + "encoding/json" + "errors" + "fmt" + "math" + "strconv" + "time" +) + +type NumericDate struct { + time.Time +} + +func Now() *NumericDate { + return NewNumericDate(time.Now()) +} + +func NewNumericDate(t time.Time) *NumericDate { + return &NumericDate{t.UTC().Truncate(TimePrecision)} +} + +func newNumericDateFromSeconds(f float64) *NumericDate { + round, frac := math.Modf(f) + + return NewNumericDate(time.Unix(int64(round), int64(frac*1e9))) +} + +func (date NumericDate) MarshalJSON() (b []byte, err error) { + var prec int + + if TimePrecision < time.Second { + prec = int(math.Log10(float64(time.Second) / float64(TimePrecision))) + } + + truncatedDate := date.UTC().Truncate(TimePrecision) + + seconds := strconv.FormatInt(truncatedDate.Unix(), 10) + nanosecondsOffset := strconv.FormatFloat(float64(truncatedDate.Nanosecond())/float64(time.Second), 'f', prec, 64) + + output := append([]byte(seconds), []byte(nanosecondsOffset)[1:]...) + + return output, nil +} + +func (date *NumericDate) UnmarshalJSON(b []byte) (err error) { + var ( + number json.Number + f float64 + ) + + if err = json.Unmarshal(b, &number); err != nil { + return fmt.Errorf("could not parse NumericData: %w", err) + } + + if f, err = number.Float64(); err != nil { + return fmt.Errorf("could not convert json number value to float: %w", err) + } + + n := newNumericDateFromSeconds(f) + *date = *n + + return nil +} + +// Int64 returns the time value with UTC as the location, truncated with TimePrecision; as a number of +// since the Unix epoch. +func (date *NumericDate) Int64() (val int64) { + if date == nil { + return 0 + } + + return date.UTC().Truncate(TimePrecision).Unix() +} + +type ClaimStrings []string + +func (s ClaimStrings) Valid(cmp string, required bool) (valid bool) { + if len(s) == 0 { + return !required + } + + for _, str := range s { + if subtle.ConstantTimeCompare([]byte(str), []byte(cmp)) == 1 { + return true + } + } + + return false +} + +func (s ClaimStrings) ValidAny(cmp ClaimStrings, required bool) (valid bool) { + if len(s) == 0 { + return !required + } + + for _, strCmp := range cmp { + for _, str := range s { + if subtle.ConstantTimeCompare([]byte(str), []byte(strCmp)) == 1 { + return true + } + } + } + + return false +} + +func (s ClaimStrings) ValidAll(cmp ClaimStrings, required bool) (valid bool) { + if len(s) == 0 { + return !required + } + +outer: + for _, strCmp := range cmp { + for _, str := range s { + if subtle.ConstantTimeCompare([]byte(str), []byte(strCmp)) == 1 { + continue outer + } + } + + return false + } + + return true +} + +func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) { + var value interface{} + + if err = json.Unmarshal(data, &value); err != nil { + return err + } + + var aud []string + + switch v := value.(type) { + case string: + aud = append(aud, v) + case []string: + aud = ClaimStrings(v) + case []interface{}: + for _, vv := range v { + vs, ok := vv.(string) + if !ok { + return ErrInvalidType + } + aud = append(aud, vs) + } + case nil: + return nil + default: + return ErrInvalidType + } + + *s = aud + + return +} + +func (s ClaimStrings) MarshalJSON() (b []byte, err error) { + if len(s) == 1 && !MarshalSingleStringAsArray { + return json.Marshal(s[0]) + } + + return json.Marshal([]string(s)) +} + +var ( + ErrInvalidType = errors.New("invalid type for claim") +) diff --git a/token/jwt/header.go b/token/jwt/header.go index 06c0571e..36550c13 100644 --- a/token/jwt/header.go +++ b/token/jwt/header.go @@ -3,10 +3,6 @@ package jwt -import ( - "authelia.com/provider/oauth2/internal/consts" -) - // Headers is the jwt headers type Headers struct { Extra map[string]any `json:"extra"` @@ -18,7 +14,7 @@ func NewHeaders() *Headers { // ToMap will transform the headers to a map structure func (h *Headers) ToMap() map[string]any { - var filter = map[string]bool{consts.JSONWebTokenHeaderAlgorithm: true} + var filter = map[string]bool{JSONWebTokenHeaderAlgorithm: true} var extra = map[string]any{} // filter known values from extra. diff --git a/token/jwt/issuer.go b/token/jwt/issuer.go new file mode 100644 index 00000000..d90f5595 --- /dev/null +++ b/token/jwt/issuer.go @@ -0,0 +1,144 @@ +package jwt + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "fmt" + + "github.com/go-jose/go-jose/v4" +) + +// NewDefaultIssuer returns a new issuer and verifies that one RS256 key exists. +func NewDefaultIssuer(keys ...jose.JSONWebKey) (issuer *DefaultIssuer, err error) { + jwks := &jose.JSONWebKeySet{ + Keys: make([]jose.JSONWebKey, len(keys)), + } + + hasRS256 := false + + for i, key := range keys { + jwks.Keys[i] = key + + if hasRS256 { + continue + } + + if key.Use != JSONWebTokenUseSignature { + continue + } + + if key.Algorithm != string(jose.RS256) { + continue + } + + hasRS256 = true + } + + if !hasRS256 { + return nil, errors.New("no RS256 signature algorithm found") + } + + return NewDefaultIssuerUnverifiedFromJWKS(jwks), nil +} + +func NewDefaultIssuerFromJWKS(jwks *jose.JSONWebKeySet) (issuer *DefaultIssuer, err error) { + for _, key := range jwks.Keys { + if key.Use != JSONWebTokenUseSignature { + continue + } + + if key.Algorithm != string(jose.RS256) { + continue + } + + return &DefaultIssuer{jwks: jwks}, nil + } + + return nil, errors.New("no RS256 signature algorithm found") +} + +// NewDefaultIssuerUnverifiedFromJWKS returns a new issuer from a jose.JSONWebKeySet without verification. +func NewDefaultIssuerUnverifiedFromJWKS(jwks *jose.JSONWebKeySet) (issuer *DefaultIssuer) { + return &DefaultIssuer{jwks: jwks} +} + +// MustNewDefaultIssuerRS256 is the same as NewDefaultIssuerRS256 but it panics if an error occurs. +func MustNewDefaultIssuerRS256(key any) (issuer *DefaultIssuer) { + var err error + + if issuer, err = NewDefaultIssuerRS256(key); err != nil { + panic(err) + } + + return issuer +} + +// NewDefaultIssuerRS256 returns an issuer with a single key and returns an error if it's not an RSA2048 or higher key. +func NewDefaultIssuerRS256(key any) (issuer *DefaultIssuer, err error) { + switch k := key.(type) { + case *rsa.PrivateKey: + if n := k.Size(); n < 256 { + return nil, fmt.Errorf("key must be an *rsa.PrivateKey with at least 2048 bits but got %d", n*8) + } + + return NewDefaultIssuerRS256Unverified(key), nil + default: + return nil, fmt.Errorf("key must be an *rsa.PrivateKey but got %T", k) + } +} + +// NewDefaultIssuerRS256Unverified returns an issuer with a single key asserting the type is an RSA key. +func NewDefaultIssuerRS256Unverified(key any) (issuer *DefaultIssuer) { + return &DefaultIssuer{ + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + Key: key, + KeyID: "default", + Algorithm: string(jose.RS256), + Use: JSONWebTokenUseSignature, + }, + }, + }, + } +} + +// GenDefaultIssuer generates a *DefaultIssuer with a random RSA key. +func GenDefaultIssuer() (issuer *DefaultIssuer, err error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + return NewDefaultIssuerRS256(key) +} + +// MustGenDefaultIssuer is the same as GenDefaultIssuer but it panics on an error. +func MustGenDefaultIssuer() (issuer *DefaultIssuer) { + var err error + + if issuer, err = GenDefaultIssuer(); err != nil { + panic(err) + } + + return issuer +} + +type DefaultIssuer struct { + jwks *jose.JSONWebKeySet +} + +func (i *DefaultIssuer) GetIssuerJWK(ctx context.Context, kid, alg, use string) (jwk *jose.JSONWebKey, err error) { + return SearchJWKS(i.jwks, kid, alg, use, false) +} + +func (i *DefaultIssuer) GetIssuerStrictJWK(ctx context.Context, kid, alg, use string) (jwk *jose.JSONWebKey, err error) { + return SearchJWKS(i.jwks, kid, alg, use, true) +} + +type Issuer interface { + GetIssuerJWK(ctx context.Context, kid, alg, use string) (jwk *jose.JSONWebKey, err error) + GetIssuerStrictJWK(ctx context.Context, kid, alg, use string) (jwk *jose.JSONWebKey, err error) +} diff --git a/token/jwt/jwt.go b/token/jwt/jwt.go deleted file mode 100644 index 4b325415..00000000 --- a/token/jwt/jwt.go +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright © 2023 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -// Package jwt is able to generate and validate json web tokens. -// Follows https://datatracker.ietf.org/doc/html/rfc7519 - -package jwt - -import ( - "context" - "crypto" - "crypto/ecdsa" - "crypto/rsa" - "crypto/sha256" - "strings" - - "github.com/go-jose/go-jose/v4" - "github.com/pkg/errors" - - "authelia.com/provider/oauth2/x/errorsx" -) - -type Signer interface { - Generate(ctx context.Context, claims MapClaims, header Mapper) (tokenString string, signature string, err error) - Validate(ctx context.Context, tokenString string) (signature string, err error) - Hash(ctx context.Context, in []byte) ([]byte, error) - Decode(ctx context.Context, tokenString string) (token *Token, err error) - GetSignature(ctx context.Context, token string) (signature string, err error) - GetSigningMethodLength(ctx context.Context) (length int) -} - -var SHA256HashSize = crypto.SHA256.Size() - -type GetPrivateKeyFunc func(ctx context.Context) (any, error) - -// DefaultSigner is responsible for generating and validating JWT challenges -type DefaultSigner struct { - GetPrivateKey GetPrivateKeyFunc -} - -// Generate generates a new authorize code or returns an error. set secret -func (j *DefaultSigner) Generate(ctx context.Context, claims MapClaims, header Mapper) (string, string, error) { - key, err := j.GetPrivateKey(ctx) - if err != nil { - return "", "", err - } - - switch t := key.(type) { - case *jose.JSONWebKey: - return generateToken(claims, header, jose.SignatureAlgorithm(t.Algorithm), t.Key) - case jose.JSONWebKey: - return generateToken(claims, header, jose.SignatureAlgorithm(t.Algorithm), t.Key) - case *rsa.PrivateKey: - return generateToken(claims, header, jose.RS256, t) - case *ecdsa.PrivateKey: - return generateToken(claims, header, jose.ES256, t) - case jose.OpaqueSigner: - switch tt := t.Public().Key.(type) { - case *rsa.PrivateKey: - alg := jose.RS256 - if len(t.Algs()) > 0 { - alg = t.Algs()[0] - } - - return generateToken(claims, header, alg, t) - case *ecdsa.PrivateKey: - alg := jose.ES256 - if len(t.Algs()) > 0 { - alg = t.Algs()[0] - } - - return generateToken(claims, header, alg, t) - default: - return "", "", errors.Errorf("unsupported private / public key pairs: %T, %T", t, tt) - } - default: - return "", "", errors.Errorf("unsupported private key type: %T", t) - } -} - -// Validate validates a token and returns its signature or an error if the token is not valid. -func (j *DefaultSigner) Validate(ctx context.Context, token string) (string, error) { - key, err := j.GetPrivateKey(ctx) - if err != nil { - return "", err - } - - if t, ok := key.(*jose.JSONWebKey); ok { - key = t.Key - } - - switch t := key.(type) { - case *rsa.PrivateKey: - return validateToken(token, t.PublicKey) - case *ecdsa.PrivateKey: - return validateToken(token, t.PublicKey) - case jose.OpaqueSigner: - return validateToken(token, t.Public().Key) - default: - return "", errors.New("Unable to validate token. Invalid PrivateKey type") - } -} - -// Decode will decode a JWT token -func (j *DefaultSigner) Decode(ctx context.Context, token string) (*Token, error) { - key, err := j.GetPrivateKey(ctx) - if err != nil { - return nil, err - } - - if t, ok := key.(*jose.JSONWebKey); ok { - key = t.Key - } - - switch t := key.(type) { - case *rsa.PrivateKey: - return decodeToken(token, t.PublicKey) - case *ecdsa.PrivateKey: - return decodeToken(token, t.PublicKey) - case jose.OpaqueSigner: - return decodeToken(token, t.Public().Key) - default: - return nil, errors.New("Unable to decode token. Invalid PrivateKey type") - } -} - -// GetSignature will return the signature of a token -func (j *DefaultSigner) GetSignature(ctx context.Context, token string) (string, error) { - return getTokenSignature(token) -} - -// Hash will return a given hash based on the byte input or an error upon fail -func (j *DefaultSigner) Hash(ctx context.Context, in []byte) ([]byte, error) { - return hashSHA256(in) -} - -// GetSigningMethodLength will return the length of the signing method -func (j *DefaultSigner) GetSigningMethodLength(ctx context.Context) int { - return SHA256HashSize -} - -func generateToken(claims MapClaims, header Mapper, signingMethod jose.SignatureAlgorithm, privateKey any) (rawToken string, sig string, err error) { - if header == nil || claims == nil { - err = errors.New("either claims or header is nil") - return - } - - token := NewWithClaims(signingMethod, claims) - token.Header = assign(token.Header, header.ToMap()) - - rawToken, err = token.SignedString(privateKey) - if err != nil { - return - } - - sig, err = getTokenSignature(rawToken) - return -} - -func decodeToken(token string, verificationKey any) (*Token, error) { - keyFunc := func(*Token) (any, error) { return verificationKey, nil } - return ParseWithClaims(token, MapClaims{}, keyFunc) -} - -func validateToken(tokenStr string, verificationKey any) (string, error) { - _, err := decodeToken(tokenStr, verificationKey) - if err != nil { - return "", err - } - return getTokenSignature(tokenStr) -} - -func getTokenSignature(token string) (string, error) { - split := strings.Split(token, ".") - if len(split) != 3 { - return "", errors.New("header, body and signature must all be set") - } - return split[2], nil -} - -func hashSHA256(in []byte) ([]byte, error) { - hash := sha256.New() - _, err := hash.Write(in) - if err != nil { - return []byte{}, errorsx.WithStack(err) - } - return hash.Sum([]byte{}), nil -} - -func assign(a, b map[string]any) map[string]any { - for k, w := range b { - if _, ok := a[k]; ok { - continue - } - a[k] = w - } - return a -} diff --git a/token/jwt/jwt_strategy.go b/token/jwt/jwt_strategy.go new file mode 100644 index 00000000..21c54ad9 --- /dev/null +++ b/token/jwt/jwt_strategy.go @@ -0,0 +1,304 @@ +package jwt + +import ( + "context" + "fmt" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + + "authelia.com/provider/oauth2/x/errorsx" +) + +// Strategy represents the strategy for encoding and decoding JWT's. It's important to note that this is an interface +// specifically so it can be mocked and the opts values have very important semantics which are difficult to document. +type Strategy interface { + // Encode a JWT as either a JWS or JWE nested JWS. + Encode(ctx context.Context, claims Claims, opts ...StrategyOpt) (tokenString string, signature string, err error) + + // Decrypt a JWT or if the provided JWT is a JWS just return it. + Decrypt(ctx context.Context, tokenStringEnc string, opts ...StrategyOpt) (tokenString, signature string, jwe *jose.JSONWebEncryption, err error) + + // Decode a JWT. This performs decryption as well as basic signature validation. Optionally the signature validation + // can be skipped and validated later using Validate. + Decode(ctx context.Context, tokenString string, opts ...StrategyOpt) (token *Token, err error) + + // Validate allows performing the signature validation step after using the Decode function without a client while + // also using WithAllowUnverified. + Validate(ctx context.Context, token *Token, opts ...StrategyOpt) (err error) +} + +type StrategyConfig interface { + // GetJWKSFetcherStrategy returns the JWKS fetcher strategy. + GetJWKSFetcherStrategy(ctx context.Context) (strategy JWKSFetcherStrategy) +} + +type JWKSFetcherStrategy interface { + // Resolve returns the JSON Web Key Set, or an error if something went wrong. The forceRefresh, if true, forces + // the strategy to fetch the key from the remote. If forceRefresh is false, the strategy may use a caching strategy + // to fetch the key. + Resolve(ctx context.Context, location string, ignoreCache bool) (jwks *jose.JSONWebKeySet, err error) +} + +// DefaultStrategy is responsible for providing JWK encoding and cryptographic functionality. +type DefaultStrategy struct { + Config StrategyConfig + Issuer Issuer +} + +func (j *DefaultStrategy) Encode(ctx context.Context, claims Claims, opts ...StrategyOpt) (tokenString string, signature string, err error) { + o := &StrategyOpts{ + headers: NewHeaders(), + } + + for _, opt := range opts { + if err = opt(o); err != nil { + return "", "", err + } + } + + var ( + keySig *jose.JSONWebKey + ) + + if o.client == nil { + if keySig, err = j.Issuer.GetIssuerJWK(ctx, "", string(jose.RS256), JSONWebTokenUseSignature); err != nil { + return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving issuer jwk: %w", err)) + } + } else if keySig, err = j.Issuer.GetIssuerJWK(ctx, o.client.GetSigningKeyID(), o.client.GetSigningAlg(), JSONWebTokenUseSignature); err != nil { + return "", "", errorsx.WithStack(fmt.Errorf("error occurred retrieving issuer jwk: %w", err)) + } + + if o.client == nil { + return EncodeCompactSigned(ctx, claims, o.headers, keySig) + } + + kid, alg, enc := o.client.GetEncryptionKeyID(), o.client.GetEncryptionAlg(), o.client.GetEncryptionEnc() + + if len(kid) == 0 && len(alg) == 0 { + return EncodeCompactSigned(ctx, claims, o.headers, keySig) + } + + if len(enc) == 0 { + enc = string(jose.A128CBC_HS256) + } + + var keyEnc *jose.JSONWebKey + + if IsEncryptedJWTClientSecretAlg(alg) { + if keyEnc, err = NewClientSecretJWKFromClient(ctx, o.client, kid, alg, enc, JSONWebTokenUseEncryption); err != nil { + return "", "", errorsx.WithStack(fmt.Errorf("Failed to encrypt the JWT using the client secret. %w", err)) + } + } else if keyEnc, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, JSONWebTokenUseEncryption, false); err != nil { + return "", "", errorsx.WithStack(fmt.Errorf("Failed to encrypt the JWT using the client configuration. %w", err)) + } + + return EncodeNestedCompactEncrypted(ctx, claims, o.headers, o.headersJWE, keySig, keyEnc, jose.ContentEncryption(enc)) +} + +func (j *DefaultStrategy) Decrypt(ctx context.Context, tokenStringEnc string, opts ...StrategyOpt) (tokenString, signature string, jwe *jose.JSONWebEncryption, err error) { + if !IsEncryptedJWT(tokenStringEnc) { + if IsSignedJWT(tokenStringEnc) { + return tokenStringEnc, "", nil, nil + } else { + return tokenStringEnc, "", nil, errorsx.WithStack(&ValidationError{text: "Provided value does not appear to be a JWE or JWS compact serialized JWT", Errors: ValidationErrorMalformedNotCompactSerialized}) + } + } + + o := &StrategyOpts{ + sigAlgorithm: SignatureAlgorithms, + keyAlgorithm: EncryptionKeyAlgorithms, + contentEncryption: ContentEncryptionAlgorithms, + } + + for _, opt := range opts { + if err = opt(o); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } + + var ( + key *jose.JSONWebKey + ) + + if jwe, err = jose.ParseEncryptedCompact(tokenStringEnc, o.keyAlgorithm, o.contentEncryption); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + var ( + kid, alg, enc string + ) + + if kid, alg, enc, _, err = headerValidateJWE(jwe.Header); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + if o.jweKeyFunc != nil { + if key, err = o.jweKeyFunc(ctx, jwe, kid, alg); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if IsEncryptedJWTClientSecretAlg(alg) { + if o.client == nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + + if key, err = NewClientSecretJWKFromClient(ctx, o.client, kid, alg, enc, JSONWebTokenUseEncryption); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, JSONWebTokenUseEncryption); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + + var tokenRaw []byte + + if tokenRaw, err = jwe.Decrypt(key); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + tokenString = string(tokenRaw) + + if signature, err = getJWTSignature(tokenString); err != nil { + return "", "", nil, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + return tokenString, signature, jwe, nil +} + +func (j *DefaultStrategy) Decode(ctx context.Context, tokenString string, opts ...StrategyOpt) (token *Token, err error) { + o := &StrategyOpts{ + sigAlgorithm: SignatureAlgorithms, + keyAlgorithm: EncryptionKeyAlgorithms, + contentEncryption: ContentEncryptionAlgorithms, + jwsKeyFunc: nil, + jweKeyFunc: nil, + } + + for _, opt := range opts { + if err = opt(o); err != nil { + return token, errorsx.WithStack(err) + } + } + + var ( + t *jwt.JSONWebToken + jwe *jose.JSONWebEncryption + ) + + tokenString, _, jwe, err = j.Decrypt(ctx, tokenString, opts...) + if err != nil { + return token, err + } + + if t, err = jwt.ParseSigned(tokenString, o.sigAlgorithm); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + if token, err = newToken(t, nil); err != nil { + return token, errorsx.WithStack(err) + } + + token.AssignJWE(jwe) + + claims := MapClaims{} + + if err = t.UnsafeClaimsWithoutVerification(&claims); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err}) + } + + token.Claims = claims + + var alg string + + if _, alg, err = headerValidateJWS(t.Headers); err != nil { + return token, errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + validate := o.client != nil || !o.allowUnverified + + if alg != JSONWebTokenAlgNone && validate { + if err = j.validate(ctx, t, &claims, o); err != nil { + return nil, errorsx.WithStack(err) + } + } + + token.valid = validate + + return token, nil +} + +func (j *DefaultStrategy) Validate(ctx context.Context, token *Token, opts ...StrategyOpt) (err error) { + if token == nil { + return errorsx.WithStack(fmt.Errorf("token is nil")) + } + + if token.valid { + return nil + } + + if token.parsedToken == nil { + return errorsx.WithStack(fmt.Errorf("token is in an inconsistent state")) + } + + o := &StrategyOpts{ + sigAlgorithm: SignatureAlgorithms, + keyAlgorithm: EncryptionKeyAlgorithms, + contentEncryption: ContentEncryptionAlgorithms, + jwsKeyFunc: nil, + jweKeyFunc: nil, + } + + for _, opt := range opts { + if err = opt(o); err != nil { + return errorsx.WithStack(err) + } + } + + if err = j.validate(ctx, token.parsedToken, &MapClaims{}, o); err != nil { + return err + } + + token.valid = true + + return nil +} + +func (j *DefaultStrategy) validate(ctx context.Context, t *jwt.JSONWebToken, dest any, o *StrategyOpts) (err error) { + var ( + key *jose.JSONWebKey + kid, alg string + ) + + if kid, alg, err = headerValidateJWS(t.Headers); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorMalformed, Inner: err}) + } + + claims := MapClaims{} + + if o.jwsKeyFunc != nil { + if key, err = o.jwsKeyFunc(ctx, t, claims); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else if o.client != nil && o.client.IsClientSigned() { + if IsSignedJWTClientSecretAlgStr(alg) { + if kid != "" { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorHeaderKeyIDInvalid, Inner: fmt.Errorf("error validating the jws header: alg '%s' does not support tokens with a kid but the token has kid '%s'", alg, kid)}) + } + + if key, err = NewClientSecretJWKFromClient(ctx, o.client, "", alg, "", JSONWebTokenUseSignature); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } else { + if key, err = FindClientPublicJWK(ctx, o.client, j.Config.GetJWKSFetcherStrategy(ctx), kid, alg, JSONWebTokenUseSignature, true); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + } + } else if key, err = j.Issuer.GetIssuerStrictJWK(ctx, kid, alg, JSONWebTokenUseSignature); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}) + } + + if err = t.Claims(getPublicJWK(key), &dest); err != nil { + return errorsx.WithStack(&ValidationError{Errors: ValidationErrorSignatureInvalid, Inner: err}) + } + + return nil +} diff --git a/token/jwt/jwt_strategy_opts.go b/token/jwt/jwt_strategy_opts.go new file mode 100644 index 00000000..e9a442c8 --- /dev/null +++ b/token/jwt/jwt_strategy_opts.go @@ -0,0 +1,180 @@ +package jwt + +import ( + "context" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" +) + +type StrategyOpts struct { + client Client + + headers, headersJWE Mapper + + sigAlgorithm []jose.SignatureAlgorithm + keyAlgorithm []jose.KeyAlgorithm + contentEncryption []jose.ContentEncryption + + jwsKeyFunc KeyFuncJWS + jweKeyFunc KeyFuncJWE + + allowUnverified bool +} + +type ( + KeyFuncJWS func(ctx context.Context, token *jwt.JSONWebToken, claims MapClaims) (jwk *jose.JSONWebKey, err error) + KeyFuncJWE func(ctx context.Context, jwe *jose.JSONWebEncryption, kid, alg string) (jwk *jose.JSONWebKey, err error) + StrategyOpt func(opts *StrategyOpts) (err error) +) + +func WithAllowUnverified() StrategyOpt { + return func(opts *StrategyOpts) (err error) { + opts.allowUnverified = true + + return nil + } +} + +func WithHeaders(headers Mapper) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + opts.headers = headers + + return nil + } +} + +func WithHeadersJWE(headers Mapper) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + opts.headersJWE = headers + + return nil + } +} + +func WithClient(client Client) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + opts.client = client + + return nil + } +} + +func WithIDTokenClient(client any) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + switch c := client.(type) { + case IDTokenClient: + opts.client = &decoratedIDTokenClient{IDTokenClient: c} + } + + return nil + } +} + +func WithUserInfoClient(client any) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + switch c := client.(type) { + case UserInfoClient: + opts.client = &decoratedUserInfoClient{UserInfoClient: c} + } + + return nil + } +} + +func WithIntrospectionClient(client any) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + switch c := client.(type) { + case IntrospectionClient: + opts.client = &decoratedIntrospectionClient{IntrospectionClient: c} + } + + return nil + } +} + +func WithJARMClient(client any) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + switch c := client.(type) { + case JARMClient: + opts.client = &decoratedJARMClient{JARMClient: c} + } + + return nil + } +} + +func WithJARClient(client any) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + switch c := client.(type) { + case JARClient: + opts.client = &decoratedJARClient{JARClient: c} + } + + return nil + } +} + +func WithJWTProfileAccessTokenClient(client any) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + switch c := client.(type) { + case JWTProfileAccessTokenClient: + opts.client = &decoratedJWTProfileAccessTokenClient{JWTProfileAccessTokenClient: c} + } + + return nil + } +} + +func WithStatelessJWTProfileIntrospectionClient(client any) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + switch c := client.(type) { + case IntrospectionClient: + opts.client = &decoratedIntrospectionClient{IntrospectionClient: c} + case JWTProfileAccessTokenClient: + opts.client = &decoratedJWTProfileAccessTokenClient{JWTProfileAccessTokenClient: c} + } + + return nil + } +} + +func WithSigAlgorithm(algs ...jose.SignatureAlgorithm) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + opts.sigAlgorithm = algs + + return nil + } +} + +func WithKeyAlgorithm(algs ...jose.KeyAlgorithm) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + opts.keyAlgorithm = algs + + return nil + } +} + +func WithContentEncryption(enc ...jose.ContentEncryption) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + opts.contentEncryption = enc + + return nil + } +} + +func WithKeyFunc(f KeyFuncJWS) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + opts.jwsKeyFunc = f + + return nil + } +} + +func WithKeyFuncJWE(f KeyFuncJWE) StrategyOpt { + return func(opts *StrategyOpts) (err error) { + opts.jweKeyFunc = f + + return nil + } +} diff --git a/token/jwt/jwt_strategy_test.go b/token/jwt/jwt_strategy_test.go new file mode 100644 index 00000000..00738fa8 --- /dev/null +++ b/token/jwt/jwt_strategy_test.go @@ -0,0 +1,724 @@ +package jwt + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "fmt" + "net/http" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultStrategy(t *testing.T) { + ctx := context.TODO() + + config := &testConfig{} + + issuerRS256, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + issuerES512, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + + issuerES512enc, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + + clientES512, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + + clientES512enc, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + + issuerJWKS := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "rs256-sig", + Key: issuerRS256, + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.RS256), + }, + { + KeyID: "es512-sig", + Key: issuerES512, + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: issuerES512enc, + Use: JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + issuerClientJWKS := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "rs256-sig", + Key: &issuerRS256.PublicKey, + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.RS256), + }, + { + KeyID: "es512-sig", + Key: &issuerES512.PublicKey, + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: &issuerES512enc.PublicKey, + Use: JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + issuer := &DefaultIssuer{ + jwks: issuerJWKS, + } + + clientIssuerJWKS := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "es512-sig", + Key: clientES512, + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: clientES512enc, + Use: JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + clientJWKS := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "es512-sig", + Key: &clientES512.PublicKey, + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: &clientES512enc.PublicKey, + Use: JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + issuerJWKSenc := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "es512-sig", + Key: &issuerES512.PublicKey, + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: &issuerES512enc.PublicKey, + Use: JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + clientJWKSenc := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "es512-sig", + Key: &clientES512.PublicKey, + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.ES512), + }, + { + KeyID: "es512-enc", + Key: &clientES512enc.PublicKey, + Use: JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A256KW), + }, + }, + } + + client := &testClient{ + kid: "es512-sig", + alg: "ES512", + encKID: "", + encAlg: "", + enc: "", + csigned: false, + jwks: clientJWKS, + jwksURI: "", + } + + clientEnc := &testClient{ + kid: "es512-sig", + alg: "ES512", + encKID: "es512-enc", + encAlg: string(jose.ECDH_ES_A256KW), + enc: string(jose.A256GCM), + csigned: false, + jwks: clientJWKSenc, + jwksURI: "", + } + + key128 := make([]byte, 32) + + _, err = rand.Read(key128) + require.NoError(t, err) + + clientEncAsymmetric := &testClient{ + kid: "es512-sig", + alg: "ES512", + encKID: "", + encAlg: string(jose.PBES2_HS256_A128KW), + enc: string(jose.A256GCM), + csigned: true, + secret: key128, + jwks: issuerJWKSenc, + jwksURI: "", + } + + strategy := &DefaultStrategy{ + Config: config, + Issuer: issuer, + } + + claims := MapClaims{ + "value": 1, + } + + headers1 := &Headers{ + Extra: map[string]any{ + JSONWebTokenHeaderType: JSONWebTokenTypeAccessToken, + }, + } + + var headersEnc *Headers + + var ( + token1, signature1 string + ) + + token1, signature1, err = strategy.Encode(ctx, claims, WithHeaders(headers1), WithClient(client)) + require.NoError(t, err) + assert.NotEmpty(t, signature1) + + require.True(t, IsSignedJWT(token1)) + + headersEnc = &Headers{} + + var ( + token2, signature2 string + ) + + headers2 := &Headers{ + Extra: map[string]any{ + JSONWebTokenHeaderType: JSONWebTokenTypeJWT, + }, + } + + token2, signature2, err = strategy.Encode(ctx, claims, WithHeaders(headers2), WithHeadersJWE(headersEnc), WithClient(clientEnc)) + require.NoError(t, err) + require.True(t, IsEncryptedJWT(token2)) + require.NotEmpty(t, signature2) + + var ( + token3, signature3 string + ) + + token3, signature3, err = strategy.Encode(ctx, claims, WithHeaders(headers1), WithHeadersJWE(headersEnc), WithClient(clientEncAsymmetric)) + require.NoError(t, err) + assert.NotEmpty(t, signature3) + + clientIssuer := &DefaultIssuer{ + jwks: clientIssuerJWKS, + } + + clientStrategy := &DefaultStrategy{ + Config: config, + Issuer: clientIssuer, + } + + issuerClient := &testClient{ + kid: "es512-sig", + alg: "ES512", + encKID: "", + encAlg: "", + enc: "", + csigned: true, + jwks: issuerClientJWKS, + jwksURI: "", + } + + tokenString, signature, jwe, err := clientStrategy.Decrypt(ctx, token2, WithClient(clientEncAsymmetric)) + require.NoError(t, err) + assert.NotEmpty(t, signature) + assert.NotEmpty(t, tokenString) + assert.NotNil(t, jwe) + + tokenString, signature, jwe, err = clientStrategy.Decrypt(ctx, token3, WithClient(clientEncAsymmetric)) + require.NoError(t, err) + + tok, err := clientStrategy.Decode(ctx, token1, WithClient(issuerClient)) + require.NoError(t, err) + require.NotNil(t, tok) + + tok, err = clientStrategy.Decode(ctx, token2, WithClient(issuerClient)) + require.NoError(t, err) + + tok, err = clientStrategy.Decode(ctx, token3, WithClient(clientEncAsymmetric)) + require.NoError(t, err) + require.NotNil(t, tok) +} + +func TestDefaultStrategy_Decode_RejectNonCompactSerializedJWT(t *testing.T) { + testCases := []struct { + name string + strategy Strategy + }{ + { + name: "RS256", + strategy: &DefaultStrategy{}, + }, + { + name: "ES256", + strategy: &DefaultStrategy{}, + }, + } + + inputs := []struct { + name string + value string + }{ + {"Empty", ""}, + {"Space", " "}, + {"TwoParts", "foo.bar"}, + {"TwoPartsEmptySecond", "foo."}, + {"TwoPartsEmptyFirst", "foo."}, + } + + for _, tc := range testCases { + for _, input := range inputs { + t.Run(fmt.Sprintf("%s/%s", tc.name, input.name), func(t *testing.T) { + _, err := tc.strategy.Decode(context.TODO(), input.value) + + assert.EqualError(t, err, "Provided value does not appear to be a JWE or JWS compact serialized JWT") + }) + } + } +} + +func TestNestedJWTEncodeDecode(t *testing.T) { + claims := MapClaims{ + "iss": "example.com", + "sub": "john", + "iat": time.Now().UTC().Unix(), + "exp": time.Now().Add(time.Hour).UTC().Unix(), + "aud": []string{"test"}, + } + + providerStrategy := &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeySigRSA, + testKeySigECDSA, + }, + }), + } + + encodeClientRSA := &testClient{ + id: "test", + kid: "test-rsa-sig", + alg: string(jose.RS256), + encKID: "test-rsa-enc", + encAlg: string(jose.RSA_OAEP_256), + enc: string(jose.A128GCM), + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicEncRSA, + testKeyPublicEncECDSA, + }, + }, + } + + tokenString, sig, err := providerStrategy.Encode(context.TODO(), claims, WithClient(encodeClientRSA)) + require.NoError(t, err) + assert.NotEmpty(t, sig) + assert.NotEmpty(t, tokenString) + + clientStrategy := &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyEncRSA, + testKeyEncECDSA, + }, + }), + } + + decodeClientRSA := &testClient{ + id: "test", + kid: "test-rsa-sig", + alg: string(jose.RS256), + encKID: "test-rsa-enc", + encAlg: string(jose.RSA_OAEP_256), + enc: string(jose.A128GCM), + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicSigRSA, + testKeyPublicSigECDSA, + }, + }, + csigned: true, + } + + token, err := clientStrategy.Decode(context.TODO(), tokenString, WithClient(decodeClientRSA)) + require.NoError(t, err) + + assert.NotNil(t, token) + + assert.NoError(t, token.Valid(ValidateAlgorithm(string(jose.RS256)), ValidateKeyAlgorithm(string(jose.RSA_OAEP_256)), ValidateContentEncryption(string(jose.A128GCM)), ValidateKeyID("test-rsa-sig"), ValidateEncryptionKeyID("test-rsa-enc"))) + assert.NoError(t, token.Claims.Valid(ValidateRequireExpiresAt(), ValidateRequireIssuedAt(), ValidateIssuer("example.com"), ValidateAudienceAny("test"))) + assert.EqualError(t, token.Claims.Valid(ValidateAudienceAny("nope")), "Token has invalid audience") + + encodeClientECDSA := &testClient{ + id: "test", + kid: "test-ecdsa-sig", + alg: string(jose.ES256), + encKID: "test-ecdsa-enc", + encAlg: string(jose.ECDH_ES_A128KW), + enc: string(jose.A128GCM), + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicEncRSA, + testKeyPublicEncECDSA, + }, + }, + } + + tokenString, sig, err = providerStrategy.Encode(context.TODO(), claims, WithClient(encodeClientECDSA)) + require.NoError(t, err) + assert.NotEmpty(t, sig) + assert.NotEmpty(t, tokenString) + + clientStrategy = &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyEncRSA, + testKeyEncECDSA, + }, + }), + } + + decodeClientECDSA := &testClient{ + id: "test", + kid: "test-ecdsa-sig", + alg: string(jose.RS256), + encKID: "test-ecdsa-enc", + encAlg: string(jose.RSA_OAEP_256), + enc: string(jose.A128GCM), + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicSigRSA, + testKeyPublicSigECDSA, + }, + }, + csigned: true, + } + + token, err = clientStrategy.Decode(context.TODO(), tokenString, WithClient(decodeClientECDSA)) + require.NoError(t, err) + + assert.NotNil(t, token) + + assert.NoError(t, token.Valid(ValidateAlgorithm(string(jose.ES256)), ValidateKeyAlgorithm(string(jose.ECDH_ES_A128KW)), ValidateContentEncryption(string(jose.A128GCM)), ValidateKeyID("test-ecdsa-sig"), ValidateEncryptionKeyID("test-ecdsa-enc"))) + assert.NoError(t, token.Claims.Valid(ValidateRequireExpiresAt(), ValidateRequireIssuedAt(), ValidateIssuer("example.com"), ValidateAudienceAny("test"))) + assert.EqualError(t, token.Claims.Valid(ValidateAudienceAny("nope")), "Token has invalid audience") + + k, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + decodeClientECDSA = &testClient{ + id: "test", + kid: "test-ecdsa-sig", + alg: string(jose.RS256), + encKID: "test-ecdsa-enc", + encAlg: string(jose.RSA_OAEP_256), + enc: string(jose.A128GCM), + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicSigRSA, + { + Key: k, + KeyID: "test-ecdsa-sig", + Use: "sig", + Algorithm: string(jose.ES256), + }, + }, + }, + csigned: true, + } + + token, err = clientStrategy.Decode(context.TODO(), tokenString, WithClient(decodeClientECDSA)) + assert.EqualError(t, err, "go-jose/go-jose: error in cryptographic primitive") + + clientStrategy = &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyEncRSA, + }, + }), + } + + token, err = clientStrategy.Decode(context.TODO(), tokenString, WithClient(decodeClientECDSA)) + assert.EqualError(t, err, "Error occurred retrieving the JSON Web Key. The JSON Web Token uses signing key with kid 'test-ecdsa-enc' which was not found") + + clientStrategy = &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyEncRSA, + { + Key: k, + KeyID: "test-ecdsa-enc", + Algorithm: string(jose.ECDH_ES_A128KW), + Use: "enc", + }, + }, + }), + } + + token, err = clientStrategy.Decode(context.TODO(), tokenString, WithClient(decodeClientECDSA)) + assert.EqualError(t, err, "go-jose/go-jose: error in cryptographic primitive") +} + +func TestShouldDecodeEncrypedTokens(t *testing.T) { + testCases := []struct { + name string + have string + }{ + { + "ShouldDecodeRS256", + testCompactSerializedNestedJWEWithRSA, + }, + { + "ShouldDecodeES256", + testCompactSerializedNestedJWEWithECDSA, + }, + } + + for _, tc := range testCases { + strategy := &DefaultStrategy{ + Config: &testConfig{}, + Issuer: NewDefaultIssuerUnverifiedFromJWKS(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyEncRSA, + testKeyEncECDSA, + }, + }), + } + + client := &testClient{ + id: "test", + jwks: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + testKeyPublicSigRSA, + testKeyPublicSigECDSA, + }, + }, + csigned: true, + } + + token, err := strategy.Decode(context.Background(), tc.have, WithClient(client)) + assert.NoError(t, err) + assert.NotNil(t, token) + + assert.NoError(t, token.Valid()) + assert.NoError(t, token.Claims.Valid(ValidateIssuer("example.com"), ValidateRequireIssuedAt(), ValidateRequireExpiresAt(), ValidateSubject("john"))) + } +} + +type testConfig struct{} + +func (*testConfig) GetJWKSFetcherStrategy(ctx context.Context) (strategy JWKSFetcherStrategy) { + return &testFetcher{client: http.DefaultClient} +} + +type testFetcher struct { + client *http.Client +} + +func (f *testFetcher) Resolve(ctx context.Context, location string, _ bool) (jwks *jose.JSONWebKeySet, err error) { + var req *http.Request + + if req, err = http.NewRequest(http.MethodGet, location, nil); err != nil { + return nil, err + } + + req.WithContext(ctx) + + var resp *http.Response + + if resp, err = f.client.Do(req); err != nil { + return nil, err + } + + defer resp.Body.Close() + + decoder := json.NewDecoder(resp.Body) + + jwks = &jose.JSONWebKeySet{} + + if err = decoder.Decode(jwks); err != nil { + return nil, err + } + + return jwks, nil +} + +var ( + testKeyBytesSigRSA = []byte{48, 130, 4, 189, 2, 1, 0, 48, 13, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 1, 5, 0, 4, 130, 4, 167, 48, 130, 4, 163, 2, 1, 0, 2, 130, 1, 1, 0, 225, 139, 207, 167, 201, 162, 77, 23, 186, 56, 185, 177, 22, 31, 170, 37, 145, 252, 98, 102, 229, 160, 101, 31, 227, 76, 122, 70, 190, 1, 96, 53, 29, 58, 134, 211, 50, 166, 81, 143, 163, 141, 118, 165, 192, 238, 26, 250, 66, 42, 204, 164, 160, 205, 124, 175, 215, 202, 172, 74, 210, 134, 147, 10, 64, 58, 134, 46, 185, 15, 164, 126, 106, 106, 89, 206, 12, 95, 74, 229, 236, 132, 101, 98, 140, 40, 94, 177, 113, 24, 215, 184, 7, 210, 23, 98, 129, 207, 27, 221, 233, 168, 88, 68, 180, 157, 65, 141, 145, 53, 186, 230, 143, 128, 77, 47, 240, 181, 142, 236, 105, 119, 177, 231, 0, 123, 86, 215, 201, 165, 158, 234, 185, 236, 225, 72, 237, 28, 192, 193, 45, 212, 194, 19, 195, 9, 140, 141, 143, 3, 148, 118, 244, 25, 50, 18, 9, 37, 199, 202, 225, 142, 44, 138, 99, 84, 25, 162, 48, 243, 44, 31, 38, 88, 62, 244, 2, 64, 84, 53, 105, 92, 175, 156, 152, 32, 191, 26, 39, 153, 166, 188, 190, 33, 221, 116, 180, 174, 229, 191, 129, 198, 183, 23, 178, 20, 183, 250, 29, 42, 63, 63, 107, 170, 126, 121, 31, 90, 17, 180, 63, 137, 86, 255, 49, 163, 181, 91, 110, 160, 163, 147, 67, 149, 152, 218, 212, 23, 231, 76, 133, 208, 190, 161, 230, 4, 156, 186, 206, 145, 106, 7, 213, 50, 161, 159, 2, 3, 1, 0, 1, 2, 130, 1, 0, 106, 140, 145, 220, 193, 244, 90, 87, 11, 50, 33, 6, 247, 92, 158, 20, 129, 146, 169, 41, 210, 240, 162, 213, 29, 155, 211, 103, 247, 250, 206, 104, 73, 22, 140, 250, 216, 194, 153, 101, 49, 238, 114, 78, 123, 134, 0, 88, 153, 73, 126, 195, 134, 243, 140, 35, 197, 221, 136, 231, 15, 237, 99, 41, 68, 142, 97, 53, 81, 87, 130, 109, 245, 247, 167, 213, 31, 35, 37, 78, 217, 28, 242, 136, 75, 142, 6, 173, 236, 175, 191, 184, 192, 121, 15, 115, 9, 191, 189, 122, 104, 23, 143, 27, 101, 247, 164, 48, 44, 153, 37, 98, 38, 8, 134, 110, 79, 88, 117, 220, 89, 54, 162, 100, 110, 101, 213, 239, 215, 216, 210, 212, 103, 101, 141, 155, 163, 163, 92, 200, 89, 244, 21, 136, 197, 41, 119, 24, 87, 64, 179, 1, 128, 223, 166, 65, 16, 163, 99, 42, 251, 49, 59, 200, 176, 174, 9, 79, 90, 174, 171, 221, 68, 38, 200, 224, 123, 116, 79, 105, 97, 196, 164, 173, 200, 47, 199, 130, 84, 201, 111, 87, 76, 249, 117, 200, 83, 104, 195, 123, 171, 176, 39, 221, 101, 23, 152, 39, 148, 179, 79, 32, 135, 148, 86, 252, 85, 226, 50, 222, 84, 230, 174, 202, 149, 64, 87, 170, 2, 7, 20, 26, 118, 251, 7, 161, 55, 11, 127, 143, 78, 169, 209, 225, 94, 173, 164, 149, 37, 191, 21, 182, 38, 56, 168, 129, 2, 129, 129, 0, 238, 131, 86, 36, 159, 247, 119, 204, 99, 214, 255, 44, 244, 160, 224, 151, 30, 172, 198, 6, 76, 189, 52, 147, 11, 99, 164, 161, 245, 49, 224, 145, 118, 186, 229, 106, 207, 53, 208, 16, 74, 222, 118, 57, 230, 237, 5, 165, 224, 90, 194, 146, 162, 98, 85, 30, 162, 195, 214, 117, 130, 141, 43, 225, 169, 222, 247, 190, 77, 187, 244, 50, 2, 35, 153, 192, 253, 128, 34, 227, 128, 209, 145, 41, 192, 79, 185, 78, 169, 144, 71, 211, 58, 107, 50, 125, 152, 174, 177, 58, 121, 239, 95, 47, 248, 156, 53, 97, 126, 48, 112, 232, 206, 60, 139, 36, 111, 213, 98, 254, 233, 211, 168, 187, 115, 205, 142, 45, 2, 129, 129, 0, 242, 21, 25, 43, 57, 228, 145, 144, 61, 19, 122, 145, 201, 78, 122, 47, 209, 218, 227, 90, 209, 224, 174, 252, 191, 209, 172, 99, 217, 112, 127, 131, 200, 134, 90, 159, 37, 16, 46, 86, 118, 145, 140, 31, 83, 194, 111, 23, 83, 18, 151, 223, 126, 58, 186, 235, 33, 180, 24, 200, 101, 114, 148, 199, 203, 57, 190, 239, 21, 45, 194, 140, 182, 45, 222, 30, 162, 173, 189, 249, 203, 158, 162, 138, 246, 185, 164, 21, 216, 228, 146, 180, 162, 165, 48, 170, 215, 113, 204, 223, 200, 194, 140, 54, 191, 157, 251, 30, 218, 126, 228, 27, 228, 158, 30, 126, 131, 243, 169, 230, 172, 88, 183, 51, 70, 241, 218, 123, 2, 129, 129, 0, 156, 61, 214, 201, 73, 45, 0, 2, 25, 8, 246, 193, 201, 66, 53, 189, 104, 239, 191, 12, 211, 106, 66, 45, 109, 17, 138, 0, 58, 49, 193, 45, 40, 252, 199, 90, 79, 128, 173, 218, 110, 97, 10, 75, 101, 213, 176, 148, 119, 194, 156, 161, 23, 212, 152, 115, 232, 37, 167, 175, 244, 164, 107, 177, 120, 232, 193, 155, 157, 42, 89, 142, 4, 206, 179, 98, 179, 237, 35, 109, 170, 174, 29, 140, 159, 24, 218, 136, 8, 21, 166, 167, 93, 38, 105, 189, 210, 173, 229, 21, 44, 89, 61, 30, 156, 154, 31, 113, 205, 11, 8, 123, 200, 213, 234, 68, 37, 42, 64, 158, 66, 40, 79, 232, 243, 180, 28, 197, 2, 129, 128, 115, 240, 44, 212, 169, 238, 80, 212, 142, 155, 180, 152, 251, 155, 77, 35, 119, 210, 232, 14, 7, 244, 30, 122, 71, 247, 200, 35, 45, 241, 21, 240, 236, 105, 132, 31, 49, 229, 244, 251, 77, 223, 217, 6, 235, 219, 115, 206, 236, 231, 59, 187, 58, 190, 47, 229, 10, 136, 49, 82, 80, 91, 182, 235, 148, 229, 252, 14, 142, 203, 18, 160, 199, 99, 98, 60, 179, 214, 151, 228, 121, 99, 105, 31, 58, 152, 160, 0, 34, 151, 29, 183, 203, 41, 104, 12, 122, 16, 51, 121, 125, 177, 198, 235, 53, 140, 24, 199, 167, 7, 28, 130, 75, 84, 122, 240, 70, 139, 188, 244, 15, 216, 145, 44, 202, 174, 107, 223, 2, 129, 128, 106, 85, 157, 106, 91, 201, 27, 113, 197, 111, 239, 104, 141, 30, 73, 67, 30, 204, 18, 195, 1, 99, 13, 200, 69, 81, 13, 185, 250, 196, 26, 127, 67, 184, 226, 65, 176, 119, 163, 86, 176, 24, 120, 179, 50, 36, 76, 156, 108, 138, 164, 204, 65, 133, 112, 236, 122, 246, 227, 137, 244, 216, 112, 246, 212, 114, 24, 155, 88, 42, 17, 161, 70, 196, 67, 90, 209, 73, 58, 73, 82, 26, 116, 15, 229, 107, 35, 158, 89, 49, 241, 154, 7, 230, 219, 92, 234, 144, 136, 4, 221, 149, 130, 120, 64, 127, 225, 248, 241, 183, 6, 25, 225, 10, 236, 21, 141, 152, 122, 70, 111, 82, 177, 175, 205, 116, 72, 142} + testKeyBytesEncRSA = []byte{48, 130, 4, 191, 2, 1, 0, 48, 13, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 1, 5, 0, 4, 130, 4, 169, 48, 130, 4, 165, 2, 1, 0, 2, 130, 1, 1, 0, 172, 196, 70, 138, 2, 105, 68, 113, 29, 120, 167, 117, 74, 253, 195, 218, 40, 62, 221, 198, 42, 216, 67, 65, 21, 181, 204, 211, 51, 45, 62, 127, 3, 219, 96, 95, 31, 191, 226, 255, 108, 87, 135, 133, 134, 197, 74, 188, 3, 244, 121, 123, 171, 192, 82, 213, 54, 61, 142, 226, 42, 71, 79, 59, 30, 197, 1, 67, 182, 236, 39, 62, 5, 234, 69, 4, 167, 72, 82, 76, 46, 146, 234, 117, 99, 90, 189, 205, 19, 75, 210, 105, 225, 110, 172, 236, 14, 158, 33, 176, 193, 58, 136, 147, 140, 151, 222, 181, 95, 11, 121, 56, 172, 215, 239, 222, 200, 76, 237, 183, 50, 104, 240, 1, 164, 197, 247, 180, 54, 216, 196, 34, 137, 211, 108, 74, 222, 188, 190, 84, 118, 244, 249, 97, 192, 147, 126, 67, 209, 24, 80, 37, 180, 88, 169, 112, 37, 242, 249, 49, 28, 152, 168, 182, 175, 21, 80, 88, 153, 23, 132, 136, 53, 25, 84, 128, 216, 88, 118, 173, 4, 241, 238, 122, 226, 190, 134, 116, 167, 94, 196, 131, 175, 156, 213, 115, 140, 105, 63, 15, 31, 237, 25, 243, 85, 156, 85, 37, 4, 238, 154, 36, 10, 235, 11, 213, 222, 238, 208, 69, 197, 201, 139, 76, 3, 137, 214, 175, 63, 112, 150, 41, 24, 122, 110, 250, 27, 14, 48, 7, 63, 158, 35, 107, 59, 185, 141, 200, 178, 123, 10, 97, 247, 100, 9, 85, 2, 3, 1, 0, 1, 2, 130, 1, 1, 0, 152, 106, 91, 244, 187, 37, 213, 60, 153, 140, 108, 231, 156, 109, 253, 207, 195, 123, 154, 185, 141, 232, 214, 132, 95, 187, 208, 100, 110, 156, 182, 170, 229, 99, 47, 69, 28, 68, 115, 229, 116, 214, 79, 119, 236, 42, 183, 192, 225, 24, 87, 232, 83, 224, 74, 243, 80, 115, 196, 79, 32, 143, 98, 133, 188, 162, 126, 120, 23, 179, 132, 247, 65, 206, 168, 110, 239, 137, 109, 25, 74, 105, 80, 48, 153, 163, 95, 24, 193, 178, 61, 130, 45, 96, 47, 107, 221, 133, 130, 33, 102, 134, 214, 32, 157, 131, 9, 246, 38, 80, 127, 244, 17, 0, 59, 220, 230, 6, 128, 29, 3, 122, 242, 105, 240, 204, 185, 182, 47, 40, 194, 94, 4, 152, 28, 251, 15, 21, 148, 149, 219, 180, 253, 154, 53, 60, 3, 82, 206, 27, 97, 137, 228, 105, 87, 0, 66, 43, 198, 230, 33, 30, 84, 92, 138, 116, 76, 202, 250, 90, 102, 81, 67, 19, 159, 126, 130, 200, 208, 208, 60, 228, 166, 103, 212, 19, 138, 196, 57, 183, 244, 13, 175, 147, 198, 124, 164, 40, 50, 3, 64, 158, 35, 238, 116, 55, 215, 168, 63, 173, 11, 78, 244, 200, 130, 120, 89, 164, 58, 35, 68, 254, 141, 69, 157, 97, 123, 97, 217, 112, 51, 2, 126, 109, 182, 243, 11, 147, 93, 125, 140, 86, 157, 156, 31, 197, 119, 225, 67, 65, 64, 190, 173, 0, 143, 1, 2, 129, 129, 0, 207, 234, 215, 66, 208, 61, 102, 219, 140, 108, 97, 58, 204, 132, 211, 206, 210, 223, 212, 210, 6, 208, 177, 231, 250, 31, 35, 164, 213, 46, 55, 179, 221, 134, 79, 253, 76, 75, 187, 250, 75, 70, 163, 150, 122, 190, 243, 196, 169, 187, 255, 91, 67, 42, 110, 35, 23, 169, 114, 47, 165, 89, 10, 97, 50, 61, 53, 45, 153, 81, 209, 82, 82, 20, 4, 15, 61, 204, 185, 24, 246, 63, 153, 3, 135, 170, 45, 112, 59, 36, 209, 49, 9, 224, 203, 226, 40, 178, 20, 19, 19, 73, 249, 21, 122, 188, 247, 161, 20, 226, 124, 14, 66, 4, 12, 37, 75, 35, 37, 36, 22, 216, 236, 189, 44, 179, 102, 193, 2, 129, 129, 0, 212, 184, 112, 164, 6, 140, 127, 190, 0, 170, 229, 42, 252, 99, 70, 29, 251, 176, 41, 87, 91, 226, 204, 173, 164, 208, 192, 70, 0, 238, 113, 97, 239, 147, 247, 156, 80, 98, 231, 47, 28, 186, 99, 25, 144, 18, 165, 68, 117, 99, 54, 99, 241, 109, 77, 82, 250, 105, 218, 148, 205, 139, 49, 181, 68, 110, 50, 146, 212, 149, 191, 97, 10, 83, 55, 240, 23, 114, 83, 116, 25, 118, 195, 85, 72, 74, 142, 57, 97, 31, 65, 4, 93, 46, 177, 32, 248, 146, 5, 0, 189, 181, 123, 116, 27, 197, 98, 195, 158, 196, 112, 60, 177, 31, 246, 237, 246, 126, 248, 63, 245, 185, 1, 224, 165, 186, 251, 149, 2, 129, 129, 0, 150, 10, 132, 249, 68, 73, 107, 54, 184, 169, 101, 169, 6, 250, 59, 215, 159, 57, 195, 221, 36, 233, 233, 216, 220, 25, 40, 161, 196, 237, 171, 104, 243, 77, 255, 223, 108, 245, 162, 91, 199, 130, 220, 126, 181, 105, 163, 132, 162, 112, 118, 160, 167, 97, 177, 69, 69, 200, 20, 12, 234, 39, 205, 99, 194, 219, 132, 202, 185, 63, 223, 236, 166, 42, 167, 155, 80, 31, 178, 219, 158, 168, 218, 133, 63, 155, 193, 90, 162, 115, 185, 58, 200, 68, 31, 29, 191, 252, 114, 156, 41, 105, 82, 132, 251, 163, 238, 151, 161, 248, 167, 73, 170, 190, 60, 253, 148, 177, 114, 22, 15, 30, 208, 8, 220, 127, 66, 129, 2, 129, 128, 60, 200, 195, 111, 43, 107, 228, 104, 191, 186, 21, 168, 21, 220, 172, 65, 143, 21, 4, 139, 48, 247, 122, 243, 55, 128, 107, 32, 213, 205, 76, 218, 230, 97, 202, 196, 128, 247, 242, 5, 181, 88, 209, 78, 145, 171, 178, 76, 0, 155, 44, 4, 157, 13, 85, 166, 27, 102, 58, 14, 129, 57, 128, 39, 194, 249, 22, 60, 124, 192, 153, 162, 58, 24, 19, 136, 232, 186, 67, 124, 142, 118, 48, 84, 227, 70, 98, 163, 164, 204, 16, 129, 21, 187, 108, 227, 246, 3, 139, 168, 109, 141, 57, 76, 177, 78, 210, 237, 1, 38, 50, 200, 52, 248, 228, 79, 149, 59, 44, 230, 225, 233, 78, 207, 9, 172, 135, 141, 2, 129, 129, 0, 159, 97, 92, 8, 76, 63, 113, 168, 5, 233, 58, 228, 155, 140, 19, 198, 225, 22, 110, 129, 71, 157, 118, 55, 94, 254, 112, 242, 244, 198, 217, 99, 175, 76, 225, 115, 14, 167, 243, 188, 219, 109, 122, 165, 15, 135, 5, 225, 235, 233, 121, 203, 211, 175, 135, 69, 205, 46, 93, 119, 95, 41, 226, 44, 233, 73, 97, 76, 15, 211, 139, 211, 206, 9, 176, 77, 30, 228, 93, 183, 93, 115, 237, 93, 140, 190, 8, 156, 58, 73, 145, 4, 224, 39, 159, 83, 234, 216, 224, 97, 165, 120, 123, 253, 159, 147, 119, 201, 144, 155, 171, 114, 129, 31, 44, 3, 58, 232, 180, 193, 155, 180, 143, 220, 132, 233, 153, 219} + testKeySigRSA jose.JSONWebKey + testKeyPublicSigRSA jose.JSONWebKey + testKeyEncRSA jose.JSONWebKey + testKeyPublicEncRSA jose.JSONWebKey + + testKeyBytesSigECDSA = []byte{48, 129, 135, 2, 1, 0, 48, 19, 6, 7, 42, 134, 72, 206, 61, 2, 1, 6, 8, 42, 134, 72, 206, 61, 3, 1, 7, 4, 109, 48, 107, 2, 1, 1, 4, 32, 225, 45, 16, 217, 198, 48, 46, 37, 59, 165, 201, 242, 244, 143, 253, 127, 88, 84, 100, 25, 17, 39, 23, 128, 105, 241, 16, 227, 43, 47, 141, 187, 161, 68, 3, 66, 0, 4, 155, 65, 151, 191, 69, 30, 154, 72, 179, 179, 4, 5, 106, 97, 81, 20, 114, 8, 188, 137, 58, 81, 123, 3, 26, 111, 172, 26, 107, 212, 60, 52, 154, 177, 135, 254, 199, 8, 246, 198, 147, 23, 228, 46, 70, 145, 133, 222, 82, 222, 243, 113, 9, 10, 149, 59, 21, 144, 195, 215, 174, 175, 82, 51} + testKeyBytesEncECDSA = []byte{48, 129, 135, 2, 1, 0, 48, 19, 6, 7, 42, 134, 72, 206, 61, 2, 1, 6, 8, 42, 134, 72, 206, 61, 3, 1, 7, 4, 109, 48, 107, 2, 1, 1, 4, 32, 107, 125, 138, 26, 157, 158, 163, 251, 241, 207, 65, 183, 174, 68, 61, 135, 25, 188, 173, 245, 37, 8, 122, 233, 53, 113, 233, 221, 63, 91, 240, 100, 161, 68, 3, 66, 0, 4, 65, 254, 168, 68, 75, 55, 98, 71, 60, 193, 218, 125, 32, 224, 117, 19, 63, 145, 62, 117, 104, 107, 245, 157, 112, 192, 2, 44, 153, 73, 158, 193, 235, 150, 58, 174, 115, 25, 79, 68, 111, 142, 179, 168, 231, 172, 54, 214, 101, 18, 244, 173, 69, 22, 255, 235, 73, 227, 247, 206, 254, 183, 17, 177} + testKeySigECDSA jose.JSONWebKey + testKeyPublicSigECDSA jose.JSONWebKey + testKeyEncECDSA jose.JSONWebKey + testKeyPublicEncECDSA jose.JSONWebKey + + testCompactSerializedNestedJWEWithRSA = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJjdHkiOiJKV1QiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1yc2EtZW5jIiwidHlwIjoiSldUIn0.oWwJ7cVudU3U_EYu9vc4bXQqg7uH3xOvmjlRDKKFssc7oRDnM5IG3mvPpPnXm-jJhB0q4pGAHsVrCeRARUHTrkNWLgU4NtNhEhaeNVBoV38KxNyMXvfBdYoc91wPwN1TBQgSSBJ7FZPwDnUODAIIxh4NTyVJt6mEPlkDNrzViwv2zRhkfosdoiUJjbee2G5tQOV3Jj5o9gKrQOwZ9fJry-zKgOSeb0VR9s9vdfL8STKnwUnQ2HIYpsEG19IpXVut17Y8cIg43n65NrAHNxv2wRj-2vWij-bXM-YugSbNB7LH1n5H2wNW20nBKlYX7oLGWJHPuVxYftFIaPsm7m3sBA.6vLOfe12rdKCd0RW.-zYmgfbJNag4h0x3NQhzZCVAgczLDFYnW7fWKE51-ZyiUllRPiNuTQg-EQvwwNmFRNvwWpWFEpuhH9ceYYqfwrY0ZAqT7m3nDM24xvx46B2jZJvY3fCxB6XA72Afw4yxnRio_KMvS_vrTGsp4sLlHujolAAc3j-8y53uJ1y1mvfsKPT8i8YLyBtOT0wm1hVZAdyue-TA7iXjXRrbe2IdbW_FagXXyBW8JhOhFe9O5-T1Ts_SbPUQhwSYb5pIfUjfDLkVeEcOB9QNQw1Ai_bN2bgOh23sD4aS6G1alWEN0zHAzp0qgS7NXbyqBQefaUlcquSdM9JB9arLVGcy2IPFuC2zy7oppfcpCqmjPhXXJ0-WYA92FSAsREJCh6KsVu445KrfEQalyfmMd6qE2agDxqnDgdrIxlMzzxwCc4FvlwZ-3c3SDY5sZK_-auynoHx1adeguU1wPWY3Wy-2tEr5qsEK_P4M2h-AvcKRvMrqp6JLu3tQeSdhBGlLoEehABWp8fQqAASHlKEOLRe8znx2qSUMmdDfry8OzDYhYzcUf14YhneoEBv-HOLKykqaPWSlE-7Mkc6BY5gnoyHqznhP5cK6q-jCIUJVBRbKeTrWT2SiwZUcv89nDJL-YAMXG3GL_POPYr_TvTueRfyQL4xm1TRmZQ.fLEVOYxpohATkcNIolk1UA" + testCompactSerializedNestedJWEWithECDSA = "eyJhbGciOiJFQ0RILUVTK0ExMjhLVyIsImN0eSI6IkpXVCIsImVuYyI6IkExMjhHQ00iLCJlcGsiOnsia3R5IjoiRUMiLCJjcnYiOiJQLTI1NiIsIngiOiJRcmdXc2wwN3lFclFjaXFMM1dKZE5FU015S01SR3d2bWRnYmlYNmRYSmpjIiwieSI6Ikl5MWRpellaZFVSTnpKS1FaVlZPVnkwSVVKWjhLaWVyN19LSllJS1hzaTAifSwia2lkIjoidGVzdC1lY2RzYS1lbmMiLCJ0eXAiOiJKV1QifQ.IDkNo5DGa6VQrL8ReJrBVixYN0S_VYYE.soDQKUuDrakfrQer.aaQWcWYoUF8ISUQ_EvkRCa75GeLFMsSK3imQjc3T0OalHsCIEXYCoV_vmjDPTd4svswMQtTiZxeajevnsBl_dtaEmykjqshxHww-07r36RhWKlix3gSTJTKUvGAhFDl24HcLnWUkZZjh0Vw0G9hLidax1OGoc43Rh08aHJ3swbj6yOA-KH-0SIBBmeK1Mfb0-1I4LRdAkCeyy4P6y8z2TvEqAtfCFDfAs5O8Zm9yVA6sxFAB5l6dK4WdotOh4F8lu-vE6MfD67Qi8xiW92ccYX7fBliNyypkDaN3B1k25N374qGXYl0_z0cX2T5ba_doVYgFNDFp.bmjYOk_ZXNTN2yZQjpFLjw" +) + +func init() { + var ( + key any + err error + ) + + if key, err = x509.ParsePKCS8PrivateKey(testKeyBytesSigRSA); err != nil { + panic(err) + } + + switch k := key.(type) { + case *rsa.PrivateKey: + testKeySigRSA = jose.JSONWebKey{ + Key: k, + KeyID: "test-rsa-sig", + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.RS256), + } + testKeyPublicSigRSA = jose.JSONWebKey{ + Key: k.Public(), + KeyID: "test-rsa-sig", + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.RS256), + } + default: + panic("unsupported private key") + } + + if key, err = x509.ParsePKCS8PrivateKey(testKeyBytesEncRSA); err != nil { + panic(err) + } + + switch k := key.(type) { + case *rsa.PrivateKey: + testKeyEncRSA = jose.JSONWebKey{ + Key: k, + KeyID: "test-rsa-enc", + Use: JSONWebTokenUseEncryption, + Algorithm: string(jose.RSA_OAEP_256), + } + testKeyPublicEncRSA = jose.JSONWebKey{ + Key: k.Public(), + KeyID: "test-rsa-enc", + Use: JSONWebTokenUseEncryption, + Algorithm: string(jose.RSA_OAEP_256), + } + default: + panic("unsupported private key") + } + + if key, err = x509.ParsePKCS8PrivateKey(testKeyBytesSigECDSA); err != nil { + panic(err) + } + + switch k := key.(type) { + case *ecdsa.PrivateKey: + testKeySigECDSA = jose.JSONWebKey{ + Key: k, + KeyID: "test-ecdsa-sig", + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.ES256), + } + testKeyPublicSigECDSA = jose.JSONWebKey{ + Key: k.Public(), + KeyID: "test-ecdsa-sig", + Use: JSONWebTokenUseSignature, + Algorithm: string(jose.ES256), + } + default: + panic("unsupported private key") + } + + if key, err = x509.ParsePKCS8PrivateKey(testKeyBytesEncECDSA); err != nil { + panic(err) + } + + switch k := key.(type) { + case *ecdsa.PrivateKey: + testKeyEncECDSA = jose.JSONWebKey{ + Key: k, + KeyID: "test-ecdsa-enc", + Use: JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A128KW), + } + testKeyPublicEncECDSA = jose.JSONWebKey{ + Key: k.Public(), + KeyID: "test-ecdsa-enc", + Use: JSONWebTokenUseEncryption, + Algorithm: string(jose.ECDH_ES_A128KW), + } + default: + panic("unsupported private key") + } +} + +func TestIniit(t *testing.T) { + claims := MapClaims{ + "iss": "example.com", + "sub": "john", + "iat": time.Now().UTC().Unix(), + "exp": time.Now().Add(time.Hour * 24 * 365 * 40).UTC().Unix(), + } + + out, _, err := EncodeNestedCompactEncrypted(context.TODO(), claims, &Headers{}, &Headers{}, &testKeySigECDSA, &testKeyPublicEncECDSA, jose.A128GCM) + + fmt.Println(err) + fmt.Println(out) +} diff --git a/token/jwt/jwt_test.go b/token/jwt/jwt_test.go deleted file mode 100644 index 42236f31..00000000 --- a/token/jwt/jwt_test.go +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright © 2023 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package jwt - -import ( - "context" - "fmt" - "strings" - "testing" - "time" - - "github.com/go-jose/go-jose/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "authelia.com/provider/oauth2/internal/gen" -) - -var header = &Headers{ - Extra: map[string]any{ - "foo": "bar", - }, -} - -func TestHash(t *testing.T) { - for k, tc := range []struct { - d string - strategy Signer - }{ - { - d: "RS256", - strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil - }}, - }, - { - d: "ES256", - strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustES256Key(), nil - }}, - }, - } { - t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { - in := []byte("foo") - out, err := tc.strategy.Hash(context.TODO(), in) - assert.NoError(t, err) - assert.NotEqual(t, in, out) - }) - } -} - -func TestAssign(t *testing.T) { - for k, c := range [][]map[string]any{ - { - {"foo": "bar"}, - {"baz": "bar"}, - {"foo": "bar", "baz": "bar"}, - }, - { - {"foo": "bar"}, - {"foo": "baz"}, - {"foo": "bar"}, - }, - { - {}, - {"foo": "baz"}, - {"foo": "baz"}, - }, - { - {"foo": "bar"}, - {"foo": "baz", "bar": "baz"}, - {"foo": "bar", "bar": "baz"}, - }, - } { - assert.EqualValues(t, c[2], assign(c[0], c[1]), "Case %d", k) - } -} - -func TestGenerateJWT(t *testing.T) { - var key any = gen.MustRSAKey() - for k, tc := range []struct { - d string - strategy Signer - resetKey func(strategy Signer) - }{ - { - d: "DefaultSigner", - strategy: &DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, - }, - resetKey: func(strategy Signer) { - key = gen.MustRSAKey() - }, - }, - { - d: "ES256JWTStrategy", - strategy: &DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, - }, - resetKey: func(strategy Signer) { - key = &jose.JSONWebKey{ - Key: gen.MustES521Key(), - Algorithm: "ES512", - } - }, - }, - { - d: "ES256JWTStrategy", - strategy: &DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return key, nil - }, - }, - resetKey: func(strategy Signer) { - key = gen.MustES256Key() - }, - }, - } { - t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { - claims := &JWTClaims{ - ExpiresAt: time.Now().UTC().Add(time.Hour), - } - - token, sig, err := tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - assert.NotEmpty(t, sig) - - sig, err = tc.strategy.Validate(context.TODO(), token) - require.NoError(t, err) - assert.NotEmpty(t, sig) - - sig, err = tc.strategy.Validate(context.TODO(), token+"."+"0123456789") - require.Error(t, err) - assert.Empty(t, sig) - - partToken := strings.Split(token, ".")[2] - - sig, err = tc.strategy.Validate(context.TODO(), partToken) - require.Error(t, err) - assert.Empty(t, sig) - - tc.resetKey(tc.strategy) - - claims = &JWTClaims{ - ExpiresAt: time.Now().UTC().Add(-time.Hour), - } - - token, sig, err = tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - assert.NotEmpty(t, sig) - - sig, err = tc.strategy.Validate(context.TODO(), token) - require.Error(t, err) - require.Empty(t, sig) - - claims = &JWTClaims{ - NotBefore: time.Now().UTC().Add(time.Hour), - } - - token, sig, err = tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - assert.NotEmpty(t, sig) - - sig, err = tc.strategy.Validate(context.TODO(), token) - require.Error(t, err) - require.Empty(t, sig, "%s", err) - }) - } -} - -func TestValidateSignatureRejectsJWT(t *testing.T) { - for k, tc := range []struct { - d string - strategy Signer - }{ - { - d: "RS256", - strategy: &DefaultSigner{GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustRSAKey(), nil - }, - }, - }, - { - d: "ES256", - strategy: &DefaultSigner{ - GetPrivateKey: func(_ context.Context) (any, error) { - return gen.MustES256Key(), nil - }, - }, - }, - } { - t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { - for k, c := range []string{ - "", - " ", - "foo.bar", - "foo.", - ".foo", - } { - _, err := tc.strategy.Validate(context.TODO(), c) - assert.Error(t, err) - t.Logf("Passed test case %d", k) - } - }) - } -} diff --git a/token/jwt/token.go b/token/jwt/token.go index 150b5da3..2dd97764 100644 --- a/token/jwt/token.go +++ b/token/jwt/token.go @@ -17,41 +17,138 @@ import ( "authelia.com/provider/oauth2/x/errorsx" ) -// Token represets a JWT Token -// This token provide an adaptation to -// transit from [jwt-go](https://github.com/dgrijalva/jwt-go) -// to [go-jose](https://github.com/square/go-jose) -// It provides method signatures compatible with jwt-go but implemented -// using go-json -type Token struct { - Header map[string]any // The first segment of the token - Claims MapClaims // The second segment of the token - Method jose.SignatureAlgorithm - valid bool +// New returns a new Token. +func New() *Token { + return &Token{ + Header: map[string]any{}, + HeaderJWE: map[string]any{}, + } } -const ( - SigningMethodNone = jose.SignatureAlgorithm(consts.JSONWebTokenAlgNone) - // This key should be use to correctly sign and verify alg:none JWT tokens - UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed" +// NewWithClaims creates an unverified Token with the given claims and signing method +func NewWithClaims(alg jose.SignatureAlgorithm, claims MapClaims) *Token { + return &Token{ + Claims: claims, + SignatureAlgorithm: alg, + Header: map[string]any{}, + HeaderJWE: map[string]any{}, + } +} - JWTHeaderType = jose.HeaderKey(consts.JSONWebTokenHeaderType) -) +// Parse is an overload for ParseCustom which accepts all normal algs including 'none'. +func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { + return ParseCustom(tokenString, keyFunc, SignatureAlgorithmsNone...) +} -const ( - JWTHeaderKeyValueType = consts.JSONWebTokenHeaderType -) +// ParseCustom parses, validates, and returns a token. The keyFunc will receive the parsed token and should +// return the key for validating. If everything is kosher, err will be nil. +func ParseCustom(tokenString string, keyFunc Keyfunc, algs ...jose.SignatureAlgorithm) (token *Token, err error) { + return ParseCustomWithClaims(tokenString, MapClaims{}, keyFunc, algs...) +} -const ( - JWTHeaderTypeValueJWT = consts.JSONWebTokenTypeJWT - JWTHeaderTypeValueAccessTokenJWT = consts.JSONWebTokenTypeAccessToken -) +// ParseWithClaims is an overload for ParseCustomWithClaims which accepts all normal algs including 'none'. +func ParseWithClaims(tokenString string, claims MapClaims, keyFunc Keyfunc) (token *Token, err error) { + return ParseCustomWithClaims(tokenString, claims, keyFunc, SignatureAlgorithmsNone...) +} + +// ParseCustomWithClaims parses, validates, and returns a token with its respective claims. The keyFunc will receive the parsed token and should +// return the key for validating. If everything is kosher, err will be nil. +func ParseCustomWithClaims(tokenString string, claims MapClaims, keyFunc Keyfunc, algs ...jose.SignatureAlgorithm) (token *Token, err error) { + var parsed *jwt.JSONWebToken + + if parsed, err = jwt.ParseSigned(tokenString, algs); err != nil { + return &Token{Claims: MapClaims(nil)}, &ValidationError{Errors: ValidationErrorMalformed, Inner: err} + } + + // fill unverified claims + // This conversion is required because go-jose supports + // only marshalling structs or maps but not alias types from maps + // + // The KeyFunc(*Token) function requires the claims to be set into the + // Token, that is an unverified token, therefore an UnsafeClaimsWithoutVerification is done first + // then with the returned key, the claims gets verified. + if err = parsed.UnsafeClaimsWithoutVerification(&claims); err != nil { + return &Token{Claims: MapClaims(nil)}, &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} + } + + // creates an unsafe token + if token, err = newToken(parsed, claims); err != nil { + return &Token{Claims: MapClaims(nil)}, err + } + + if keyFunc == nil { + return token, &ValidationError{Errors: ValidationErrorUnverifiable, text: "no Keyfunc was provided."} + } + + var key any + + if key, err = keyFunc(token); err != nil { + // keyFunc returned an error + var ve *ValidationError + + if errors.As(err, &ve) { + return token, ve + } + + return token, &ValidationError{Errors: ValidationErrorUnverifiable, Inner: err} + } + + if key == nil { + return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: "keyfunc returned a nil verification key"} + } + // To verify signature go-jose requires a pointer to + // public key instead of the public key value. + // The pointer values provides that pointer. + // E.g. transform rsa.PublicKey -> *rsa.PublicKey + key = pointer(key) + + // verify signature with returned key + _, validNoneKey := key.(*unsafeNoneMagicConstant) + isSignedToken := !(token.SignatureAlgorithm == SigningMethodNone && validNoneKey) + if isSignedToken { + if err = parsed.Claims(key, &claims); err != nil { + return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: err.Error()} + } + } + + // Validate claims + // This validation is performed to be backwards compatible + // with jwt-go library behavior + if err = claims.Valid(); err != nil { + if e, ok := err.(*ValidationError); !ok { + err = &ValidationError{Inner: e, Errors: ValidationErrorClaimsInvalid} + } + + return token, err + } -type unsafeNoneMagicConstant string + token.valid = true + + return token, nil +} + +// Token represets a JWT Token. +type Token struct { + KeyID string + SignatureAlgorithm jose.SignatureAlgorithm // alg (JWS) + EncryptionKeyID string + KeyAlgorithm jose.KeyAlgorithm // alg (JWE) + ContentEncryption jose.ContentEncryption // enc (JWE) + CompressionAlgorithm jose.CompressionAlgorithm // zip (JWE) + + Header map[string]any + HeaderJWE map[string]any + + Claims Claims -// Valid informs if the token was verified against a given verification key + parsedToken *jwt.JSONWebToken + + valid bool +} + +// IsSignatureValid informs if the token was verified against a given verification key // and claims are valid -func (t *Token) Valid() bool { +func (t *Token) IsSignatureValid() bool { return t.valid } @@ -60,198 +157,411 @@ func (t *Token) Valid() bool { // // > For a type to be a Claims object, it must just have a Valid method that determines // if the token is invalid for any supported reason -type Claims interface { - Valid() error +// type Claims interface { +// Valid() error +//} + +func (t *Token) toSignedJoseHeader() (header map[jose.HeaderKey]any) { + header = map[jose.HeaderKey]any{ + JSONWebTokenHeaderType: JSONWebTokenTypeJWT, + } + + for k, v := range t.Header { + header[jose.HeaderKey(k)] = v + } + + return header } -// NewWithClaims creates an unverified Token with the given claims and signing method -func NewWithClaims(method jose.SignatureAlgorithm, claims MapClaims) *Token { - return &Token{ - Claims: claims, - Method: method, - Header: map[string]any{}, +func (t *Token) toEncryptedJoseHeader() (header map[jose.HeaderKey]any) { + header = map[jose.HeaderKey]any{ + JSONWebTokenHeaderType: JSONWebTokenTypeJWT, } + + if cty, ok := t.Header[JSONWebTokenHeaderType]; ok { + header[JSONWebTokenHeaderContentType] = cty + } + + for k, v := range t.HeaderJWE { + header[jose.HeaderKey(k)] = v + } + + return header +} + +// SetJWS sets the JWS output values. +func (t *Token) SetJWS(header Mapper, claims Claims, kid string, alg jose.SignatureAlgorithm) { + assign(t.Header, header.ToMap()) + + t.KeyID = kid + t.SignatureAlgorithm = alg + + t.Claims = claims } -func (t *Token) toJoseHeader() map[jose.HeaderKey]any { - h := map[jose.HeaderKey]any{ - JWTHeaderType: JWTHeaderTypeValueJWT, +// SetJWE sets the JWE output values. +func (t *Token) SetJWE(header Mapper, kid string, alg jose.KeyAlgorithm, enc jose.ContentEncryption, zip jose.CompressionAlgorithm) { + assign(t.HeaderJWE, header.ToMap()) + + t.EncryptionKeyID = kid + t.KeyAlgorithm = alg + t.ContentEncryption = enc + t.CompressionAlgorithm = zip +} + +// AssignJWE assigns values derived from the JWE decryption process to the Token. +func (t *Token) AssignJWE(jwe *jose.JSONWebEncryption) { + if jwe == nil { + return } - for k, v := range t.Header { - h[jose.HeaderKey(k)] = v + + t.HeaderJWE = map[string]any{ + JSONWebTokenHeaderAlgorithm: jwe.Header.Algorithm, + } + + if jwe.Header.KeyID != "" { + t.HeaderJWE[JSONWebTokenHeaderKeyIdentifier] = jwe.Header.KeyID + t.EncryptionKeyID = jwe.Header.KeyID + } + + for header, value := range jwe.Header.ExtraHeaders { + h := string(header) + + t.HeaderJWE[h] = value + + switch h { + case JSONWebTokenHeaderEncryptionAlgorithm: + if v, ok := value.(string); ok { + t.ContentEncryption = jose.ContentEncryption(v) + } + case JSONWebTokenHeaderCompressionAlgorithm: + if v, ok := value.(string); ok { + t.CompressionAlgorithm = jose.CompressionAlgorithm(v) + } + } } - return h + + t.KeyAlgorithm = jose.KeyAlgorithm(jwe.Header.Algorithm) } -// SignedString provides a compatible `jwt-go` Token.SignedString method +// CompactEncrypted serializes this token as a Compact Encrypted string, and returns the token string, signature, and +// an error if one occurred. +func (t *Token) CompactEncrypted(keySig, keyEnc any) (tokenString, signature string, err error) { + var ( + signed string + ) + + if signed, signature, err = t.CompactSigned(keySig); err != nil { + return "", "", err + } + + rcpt := jose.Recipient{ + Algorithm: t.KeyAlgorithm, + Key: keyEnc, + } + + opts := &jose.EncrypterOptions{ + Compression: t.CompressionAlgorithm, + ExtraHeaders: t.toEncryptedJoseHeader(), + } + + if _, ok := opts.ExtraHeaders[JSONWebTokenHeaderContentType]; !ok { + var typ any + + if typ, ok = t.Header[JSONWebTokenHeaderType]; ok { + opts.ExtraHeaders[JSONWebTokenHeaderContentType] = typ + } else { + opts.ExtraHeaders[JSONWebTokenHeaderContentType] = JSONWebTokenTypeJWT + } + } + + var encrypter jose.Encrypter + + if encrypter, err = jose.NewEncrypter(t.ContentEncryption, rcpt, opts); err != nil { + return "", "", errorsx.WithStack(err) + } + + var token *jose.JSONWebEncryption + + if token, err = encrypter.Encrypt([]byte(signed)); err != nil { + return "", "", errorsx.WithStack(err) + } + + if tokenString, err = token.CompactSerialize(); err != nil { + return "", "", errorsx.WithStack(err) + } + + return tokenString, signature, nil +} + +// CompactSigned serializes this token as a Compact Signed string, and returns the token string, signature, and +// an error if one occurred. +func (t *Token) CompactSigned(k any) (tokenString, signature string, err error) { + if tokenString, err = t.CompactSignedString(k); err != nil { + return "", "", err + } + + if signature, err = getJWTSignature(tokenString); err != nil { + return "", "", err + } + + return tokenString, signature, nil +} + +// CompactSignedString provides a compatible `jwt-go` Token.CompactSigned method // // > Get the complete, signed token -func (t *Token) SignedString(k any) (rawToken string, err error) { - if _, ok := k.(unsafeNoneMagicConstant); ok { - rawToken, err = unsignedToken(t) - return +func (t *Token) CompactSignedString(k any) (tokenString string, err error) { + if isUnsafeNoneMagicConstant(k) { + return unsignedToken(t) } - var signer jose.Signer - key := jose.SigningKey{ - Algorithm: t.Method, + Algorithm: t.SignatureAlgorithm, Key: k, } - opts := &jose.SignerOptions{ExtraHeaders: t.toJoseHeader()} - signer, err = jose.NewSigner(key, opts) - if err != nil { - err = errorsx.WithStack(err) - return + + opts := &jose.SignerOptions{ExtraHeaders: t.toSignedJoseHeader()} + + var signer jose.Signer + + if signer, err = jose.NewSigner(key, opts); err != nil { + return "", errorsx.WithStack(err) } // A explicit conversion from type alias MapClaims // to map[string]any is required because the // go-jose CompactSerialize() only support explicit maps // as claims or structs but not type aliases from maps. - claims := map[string]any(t.Claims) - rawToken, err = jwt.Signed(signer).Claims(claims).Serialize() - if err != nil { - err = &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} - return + // claims := t.Claims.ToMapClaims() + + if tokenString, err = jwt.Signed(signer).Claims(t.Claims.ToMapClaims().ToMap()).Serialize(); err != nil { + return "", &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} } - return + + return tokenString, nil } -func unsignedToken(t *Token) (string, error) { - t.Header[consts.JSONWebTokenHeaderAlgorithm] = consts.JSONWebTokenAlgNone - if _, ok := t.Header[string(JWTHeaderType)]; !ok { - t.Header[string(JWTHeaderType)] = JWTHeaderTypeValueJWT +// Valid validates the token headers given various input options. This does not validate any claims. +func (t *Token) Valid(opts ...HeaderValidationOption) (err error) { + vopts := &HeaderValidationOptions{ + types: []string{JSONWebTokenTypeJWT}, } - hbytes, err := json.Marshal(&t.Header) - if err != nil { - return "", errorsx.WithStack(err) + + for _, opt := range opts { + opt(vopts) } - bbytes, err := json.Marshal(&t.Claims) - if err != nil { - return "", errorsx.WithStack(err) + + vErr := new(ValidationError) + + if !t.valid { + vErr.Inner = errors.New("token has an invalid or unverified signature") + vErr.Errors |= ValidationErrorSignatureInvalid } - h := base64.RawURLEncoding.EncodeToString(hbytes) - b := base64.RawURLEncoding.EncodeToString(bbytes) - return fmt.Sprintf("%v.%v.", h, b), nil -} -func newToken(parsedToken *jwt.JSONWebToken, claims MapClaims) (*Token, error) { - token := &Token{Claims: claims} - if len(parsedToken.Headers) != 1 { - return nil, &ValidationError{text: fmt.Sprintf("only one header supported, got %v", len(parsedToken.Headers)), Errors: ValidationErrorMalformed} + if t.HeaderJWE != nil && (t.KeyAlgorithm != "" || t.ContentEncryption != "") { + var ( + typ any + ok bool + ) + + if typ, ok = t.HeaderJWE[JSONWebTokenHeaderType]; !ok || typ != JSONWebTokenTypeJWT { + vErr.Inner = errors.New("token was encrypted with invalid typ") + vErr.Errors |= ValidationErrorHeaderEncryptionTypeInvalid + } + + ttyp := t.Header[JSONWebTokenHeaderType] + cty := t.HeaderJWE[JSONWebTokenHeaderContentType] + + if cty != ttyp { + vErr.Inner = errors.New("token was encrypted with a cty value that doesn't match the typ value") + vErr.Errors |= ValidationErrorHeaderContentTypeInvalidMismatch + } + + if len(vopts.types) != 0 { + if !validateTokenTypeValue(vopts.types, cty) { + vErr.Inner = errors.New("token was encrypted with an invalid cty") + vErr.Errors |= ValidationErrorHeaderContentTypeInvalid + } + } } - // copy headers - h := parsedToken.Headers[0] - token.Header = map[string]any{ - consts.JSONWebTokenHeaderAlgorithm: h.Algorithm, + if len(vopts.types) != 0 { + if !validateTokenType(vopts.types, t.Header) { + vErr.Inner = errors.New("token was signed with an invalid typ") + vErr.Errors |= ValidationErrorHeaderTypeInvalid + } } - if h.KeyID != "" { - token.Header[consts.JSONWebTokenHeaderKeyIdentifier] = h.KeyID + + if len(vopts.alg) != 0 { + if vopts.alg != string(t.SignatureAlgorithm) { + vErr.Inner = errors.New("token was signed with an invalid alg") + vErr.Errors |= ValidationErrorHeaderAlgorithmInvalid + } } - for k, v := range h.ExtraHeaders { - token.Header[string(k)] = v + + if len(vopts.kid) != 0 { + if vopts.kid != t.KeyID { + vErr.Inner = errors.New("token was signed with an invalid kid") + vErr.Errors |= ValidationErrorHeaderKeyIDInvalid + } + } + + if len(vopts.keyAlg) != 0 && len(t.KeyAlgorithm) != 0 { + if vopts.keyAlg != string(t.KeyAlgorithm) { + vErr.Inner = errors.New("token was encrypted with an invalid alg") + vErr.Errors |= ValidationErrorHeaderKeyAlgorithmInvalid + } } - token.Method = jose.SignatureAlgorithm(h.Algorithm) + if len(vopts.contentEnc) != 0 && len(t.ContentEncryption) != 0 { + if vopts.contentEnc != string(t.ContentEncryption) { + vErr.Inner = errors.New("token was encrypted with an invalid enc") + vErr.Errors |= ValidationErrorHeaderContentEncryptionInvalid + } + } - return token, nil + if len(vopts.kidEnc) != 0 && len(t.EncryptionKeyID) != 0 { + if vopts.kidEnc != t.EncryptionKeyID { + vErr.Inner = errors.New("token was encrypted with an invalid kid") + vErr.Errors |= ValidationErrorHeaderEncryptionKeyIDInvalid + } + } + + if vErr.valid() { + return nil + } + + return vErr } -// Keyfunc is used by parsing methods to supply the key for verification. The function receives the parsed, but -// unverified Token. This allows you to use properties in the Header of the token (such as `kid`) to identify which key -// to use. -type Keyfunc func(*Token) (any, error) +// IsJWTProfileAccessToken returns true if the token is a JWT Profile Access Token. +func (t *Token) IsJWTProfileAccessToken() (ok bool) { + var ( + raw any + cty, typ string + ) -// Parse is an overload for ParseCustom which accepts all normal algs including 'none'. -func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { - return ParseCustom(tokenString, keyFunc, consts.JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512) + if t.HeaderJWE != nil && len(t.HeaderJWE) > 0 { + if raw, ok = t.HeaderJWE[JSONWebTokenHeaderContentType]; ok { + cty, ok = raw.(string) + + if !ok { + return false + } + + if cty != JSONWebTokenTypeAccessToken && cty != JSONWebTokenTypeAccessTokenAlternative { + return false + } + } + } + + if raw, ok = t.Header[JSONWebTokenHeaderType]; !ok { + return false + } + + typ, ok = raw.(string) + + return ok && (typ == JSONWebTokenTypeAccessToken || typ == JSONWebTokenTypeAccessTokenAlternative) } -// ParseCustom parses, validates, and returns a token. The keyFunc will receive the parsed token and should -// return the key for validating. If everything is kosher, err will be nil. -func ParseCustom(tokenString string, keyFunc Keyfunc, algs ...jose.SignatureAlgorithm) (*Token, error) { - return ParseCustomWithClaims(tokenString, MapClaims{}, keyFunc, algs...) +type HeaderValidationOption func(opts *HeaderValidationOptions) + +type HeaderValidationOptions struct { + types []string + alg string + kid string + kidEnc string + keyAlg string + contentEnc string } -// ParseWithClaims is an overload for ParseCustomWithClaims which accepts all normal algs including 'none'. -func ParseWithClaims(rawToken string, claims MapClaims, keyFunc Keyfunc) (*Token, error) { - return ParseCustomWithClaims(rawToken, claims, keyFunc, consts.JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512) +func ValidateTypes(types ...string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { + validator.types = types + } } -// ParseCustomWithClaims parses, validates, and returns a token with its respective claims. The keyFunc will receive the parsed token and should -// return the key for validating. If everything is kosher, err will be nil. -func ParseCustomWithClaims(rawToken string, claims MapClaims, keyFunc Keyfunc, algs ...jose.SignatureAlgorithm) (*Token, error) { - // Parse the token. - parsedToken, err := jwt.ParseSigned(rawToken, algs) - if err != nil { - return &Token{}, &ValidationError{Errors: ValidationErrorMalformed, text: err.Error()} +func ValidateKeyID(kid string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { + validator.kid = kid } +} - // fill unverified claims - // This conversion is required because go-jose supports - // only marshalling structs or maps but not alias types from maps - // - // The KeyFunc(*Token) function requires the claims to be set into the - // Token, that is an unverified token, therefore an UnsafeClaimsWithoutVerification is done first - // then with the returned key, the claims gets verified. - if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, &ValidationError{Errors: ValidationErrorClaimsInvalid, text: err.Error()} +func ValidateAlgorithm(alg string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { + validator.alg = alg } +} - // creates an unsafe token - token, err := newToken(parsedToken, claims) - if err != nil { - return nil, err +func ValidateEncryptionKeyID(kid string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { + validator.kidEnc = kid } +} - if keyFunc == nil { - return token, &ValidationError{Errors: ValidationErrorUnverifiable, text: "no Keyfunc was provided."} +func ValidateKeyAlgorithm(alg string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { + validator.keyAlg = alg } +} - // Call keyFunc callback to get verification key - verificationKey, err := keyFunc(token) - if err != nil { - // keyFunc returned an error - var ve *ValidationError +func ValidateContentEncryption(enc string) HeaderValidationOption { + return func(validator *HeaderValidationOptions) { + validator.contentEnc = enc + } +} - if errors.As(err, &ve) { - return token, ve - } +func unsignedToken(token *Token) (tokenString string, err error) { + token.Header[JSONWebTokenHeaderAlgorithm] = JSONWebTokenAlgNone - return token, &ValidationError{Errors: ValidationErrorUnverifiable, Inner: err} + if _, ok := token.Header[JSONWebTokenHeaderType]; !ok { + token.Header[JSONWebTokenHeaderType] = JSONWebTokenTypeJWT } - if verificationKey == nil { - return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: "keyfunc returned a nil verification key"} + + var ( + hbytes, bbytes []byte + ) + + if hbytes, err = json.Marshal(&token.Header); err != nil { + return "", errorsx.WithStack(err) } - // To verify signature go-jose requires a pointer to - // public key instead of the public key value. - // The pointer values provides that pointer. - // E.g. transform rsa.PublicKey -> *rsa.PublicKey - verificationKey = pointer(verificationKey) - // verify signature with returned key - _, validNoneKey := verificationKey.(*unsafeNoneMagicConstant) - isSignedToken := !(token.Method == SigningMethodNone && validNoneKey) - if isSignedToken { - if err := parsedToken.Claims(verificationKey, &claims); err != nil { - return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: err.Error()} - } + if bbytes, err = json.Marshal(&token.Claims); err != nil { + return "", errorsx.WithStack(err) } - // Validate claims - // This validation is performed to be backwards compatible - // with jwt-go library behavior - if err := claims.Valid(); err != nil { - if e, ok := err.(*ValidationError); !ok { - err = &ValidationError{Inner: e, Errors: ValidationErrorClaimsInvalid} - } - return token, err + return fmt.Sprintf("%s.%s.", base64.RawURLEncoding.EncodeToString(hbytes), base64.RawURLEncoding.EncodeToString(bbytes)), nil +} + +func newToken(parsedToken *jwt.JSONWebToken, claims MapClaims) (*Token, error) { + token := &Token{Claims: claims, parsedToken: parsedToken} + + if token.Claims == nil { + token.Claims = MapClaims{} + } + + if len(parsedToken.Headers) != 1 { + return nil, &ValidationError{text: fmt.Sprintf("only one header supported, got %v", len(parsedToken.Headers)), Errors: ValidationErrorMalformed} + } + + // copy headers + h := parsedToken.Headers[0] + token.Header = map[string]any{ + JSONWebTokenHeaderAlgorithm: h.Algorithm, + } + + token.SignatureAlgorithm = jose.SignatureAlgorithm(h.Algorithm) + + if h.KeyID != "" { + token.Header[consts.JSONWebTokenHeaderKeyIdentifier] = h.KeyID + token.KeyID = h.KeyID + } + + for k, v := range h.ExtraHeaders { + token.Header[string(k)] = v } - // set token as verified and validated - token.valid = true return token, nil } @@ -265,3 +575,52 @@ func pointer(v any) any { } return v } + +func validateTokenType(values []string, header map[string]any) bool { + var ( + raw any + ok bool + ) + + if raw, ok = header[consts.JSONWebTokenHeaderType]; !ok { + return false + } + + return validateTokenTypeValue(values, raw) +} + +func validateTokenTypeValue(values []string, raw any) bool { + var ( + typ string + ok bool + ) + + if typ, ok = raw.(string); !ok { + return false + } + + for _, t := range values { + if t == typ { + return true + } + } + + return false +} + +func isUnsafeNoneMagicConstant(k any) bool { + switch key := k.(type) { + case unsafeNoneMagicConstant: + return true + case jose.JSONWebKey: + if _, ok := key.Key.(unsafeNoneMagicConstant); ok { + return true + } + case *jose.JSONWebKey: + if _, ok := key.Key.(unsafeNoneMagicConstant); ok { + return true + } + } + + return false +} diff --git a/token/jwt/token_test.go b/token/jwt/token_test.go index 81cb8fcc..9169170a 100644 --- a/token/jwt/token_test.go +++ b/token/jwt/token_test.go @@ -18,7 +18,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "authelia.com/provider/oauth2/internal/consts" "authelia.com/provider/oauth2/internal/gen" ) @@ -49,16 +48,16 @@ func TestUnsignedToken(t *testing.T) { "sub": "nestor", }) token.Header = tc.jwtHeaders - rawToken, err := token.SignedString(key) + rawToken, err := token.CompactSignedString(key) require.NoError(t, err) require.NotEmpty(t, rawToken) parts := strings.Split(rawToken, ".") require.Len(t, parts, 3) require.Empty(t, parts[2]) - tk, err := jwt.ParseSigned(rawToken, []jose.SignatureAlgorithm{consts.JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512}) + tk, err := jwt.ParseSigned(rawToken, []jose.SignatureAlgorithm{JSONWebTokenAlgNone, jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512, jose.ES256, jose.ES384, jose.ES512}) require.NoError(t, err) require.Len(t, tk.Headers, 1) - require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[(consts.JSONWebTokenHeaderType)]) + require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[(JSONWebTokenHeaderType)]) }) } } @@ -72,12 +71,12 @@ func TestJWTHeaders(t *testing.T) { { name: "set JWT as 'typ' when the the type is not specified in the headers", jwtHeaders: map[string]any{}, - expectedType: JWTHeaderTypeValueJWT, + expectedType: JSONWebTokenTypeJWT, }, { name: "'typ' set explicitly", - jwtHeaders: map[string]any{JWTHeaderKeyValueType: JWTHeaderTypeValueAccessTokenJWT}, - expectedType: JWTHeaderTypeValueAccessTokenJWT, + jwtHeaders: map[string]any{JSONWebTokenHeaderType: JSONWebTokenTypeAccessToken}, + expectedType: JSONWebTokenTypeAccessToken, }, } for _, tc := range testCases { @@ -87,7 +86,7 @@ func TestJWTHeaders(t *testing.T) { require.NoError(t, err) require.Len(t, tk.Headers, 1) require.Equal(t, tk.Headers[0].Algorithm, "RS256") - require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[(JWTHeaderKeyValueType)]) + require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[(JSONWebTokenHeaderType)]) }) } } @@ -102,7 +101,6 @@ var ( nilKeyFunc Keyfunc = nil ) -// Many test cases where taken from https://github.com/dgrijalva/jwt-go/blob/master/parser_test.go // Test cases related to json.Number where excluded because that is not supported by go-jose, // it is not used here and therefore not supported. // @@ -322,12 +320,12 @@ func TestParser_Parse(t *testing.T) { given: given{ name: "used before issued", generate: &generate{ - claims: MapClaims{"foo": "bar", consts.ClaimIssuedAt: time.Now().Unix() + 500}, + claims: MapClaims{"foo": "bar", ClaimIssuedAt: time.Now().Unix() + 500}, }, }, expected: expected{ keyFunc: defaultKeyFunc, - claims: MapClaims{"foo": "bar", consts.ClaimIssuedAt: time.Now().Unix() + 500}, + claims: MapClaims{"foo": "bar", ClaimIssuedAt: time.Now().Unix() + 500}, valid: false, errors: ValidationErrorIssuedAt, }, @@ -419,7 +417,7 @@ func TestParser_Parse(t *testing.T) { // Figure out correct claims type token, err = ParseWithClaims(data.tokenString, MapClaims{}, data.keyFunc) // Verify result matches expectation - assert.EqualValues(t, data.claims, token.Claims) + assert.EqualValues(t, data.claims, token.Claims.ToMapClaims()) if data.valid && err != nil { t.Errorf("[%v] Error while verifying token: %T:%v", data.name, err, err) } @@ -428,7 +426,7 @@ func TestParser_Parse(t *testing.T) { t.Errorf("[%v] Invalid token passed validation", data.name) } - if (err == nil && !token.Valid()) || (err != nil && token.Valid()) { + if (err == nil && !token.IsSignatureValid()) || (err != nil && token.IsSignatureValid()) { t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name) } @@ -453,7 +451,7 @@ func TestParser_Parse(t *testing.T) { func makeSampleToken(c MapClaims, m jose.SignatureAlgorithm, key any) string { token := NewWithClaims(m, c) - s, e := token.SignedString(key) + s, e := token.CompactSignedString(key) if e != nil { panic(e.Error()) @@ -465,7 +463,7 @@ func makeSampleToken(c MapClaims, m jose.SignatureAlgorithm, key any) string { func makeSampleTokenWithCustomHeaders(c MapClaims, m jose.SignatureAlgorithm, headers map[string]any, key any) string { token := NewWithClaims(m, c) token.Header = headers - s, e := token.SignedString(key) + s, e := token.CompactSignedString(key) if e != nil { panic(e.Error()) diff --git a/token/jwt/util.go b/token/jwt/util.go new file mode 100644 index 00000000..20e3ccdd --- /dev/null +++ b/token/jwt/util.go @@ -0,0 +1,545 @@ +package jwt + +import ( + "context" + "crypto" + "crypto/aes" + "crypto/sha256" + "crypto/sha512" + "fmt" + "hash" + "reflect" + "regexp" + "strings" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/pkg/errors" +) + +var ( + reSignedJWT = regexp.MustCompile(`^[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.([-_A-Za-z0-9]+)?$`) + reEncryptedJWT = regexp.MustCompile(`^[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+\.[-_A-Za-z0-9]+$`) +) + +// IsSignedJWT returns true if a given token string meets the basic criteria of a compact serialized signed JWT. +func IsSignedJWT(tokenString string) (signed bool) { + return reSignedJWT.MatchString(tokenString) +} + +// IsEncryptedJWT returns true if a given token string meets the basic criteria of a compact serialized encrypted JWT. +func IsEncryptedJWT(tokenString string) (encrypted bool) { + return reEncryptedJWT.MatchString(tokenString) +} + +// IsSignedJWTClientSecretAlgStr returns true if the given alg string is a client secret based signature algorithm. +func IsSignedJWTClientSecretAlgStr(alg string) (csa bool) { + if a := jose.SignatureAlgorithm(alg); IsSignedJWTClientSecretAlg(a) { + return true + } + + return false +} + +func IsSignedJWTClientSecretAlg(alg jose.SignatureAlgorithm) (csa bool) { + switch alg { + case jose.HS256, jose.HS384, jose.HS512: + return true + default: + return false + } +} + +// IsEncryptedJWTClientSecretAlg returns true if a given alg string is a client secret based encryption algorithm +// i.e. symmetric. +func IsEncryptedJWTClientSecretAlg(alg string) (csa bool) { + switch a := jose.KeyAlgorithm(alg); a { + case jose.A128KW, jose.A192KW, jose.A256KW, jose.DIRECT, jose.A128GCMKW, jose.A192GCMKW, jose.A256GCMKW: + return true + default: + return IsEncryptedJWTPasswordBasedAlg(a) + } +} + +// IsEncryptedJWTPasswordBasedAlg returns true if a given jose.KeyAlgorithm is a Password Based Algorithm. +func IsEncryptedJWTPasswordBasedAlg(alg jose.KeyAlgorithm) (pba bool) { + switch alg { + case jose.PBES2_HS256_A128KW, jose.PBES2_HS384_A192KW, jose.PBES2_HS512_A256KW: + return true + default: + return false + } +} + +func headerValidateJWS(headers []jose.Header) (kid, alg string, err error) { + switch len(headers) { + case 1: + break + case 0: + return "", "", fmt.Errorf("jws header is missing") + default: + return "", "", fmt.Errorf("jws header is malformed") + } + + if headers[0].Algorithm == "" && headers[0].KeyID == "" { + return "", "", fmt.Errorf("jws header 'alg' and 'kid' values are missing or empty") + } + + if headers[0].JSONWebKey != nil { + return "", "", fmt.Errorf("jws header 'jwk' value is present but not supported") + } + + return headers[0].KeyID, headers[0].Algorithm, nil +} + +func headerValidateJWE(header jose.Header) (kid, alg, enc, cty string, err error) { + if header.KeyID == "" && !IsEncryptedJWTClientSecretAlg(header.Algorithm) { + return "", "", "", "", fmt.Errorf("jwe header 'kid' value is missing or empty") + } + + if header.Algorithm == "" { + return "", "", "", "", fmt.Errorf("jwe header 'alg' value is missing or empty") + } + + var ( + value any + ok bool + ) + + if IsEncryptedJWTPasswordBasedAlg(jose.KeyAlgorithm(header.Algorithm)) { + if value, ok = header.ExtraHeaders[JSONWebTokenHeaderPBES2Count]; ok { + switch p2c := value.(type) { + case float64: + if p2c > 5000000 { + return "", "", "", "", fmt.Errorf("jwe header 'p2c' has an invalid value '%d': more than 5,000,000", int(p2c)) + } else if p2c < 200000 { + return "", "", "", "", fmt.Errorf("jwe header 'p2c' has an invalid value '%d': less than 200,000", int(p2c)) + } + default: + return "", "", "", "", fmt.Errorf("jwe header 'p2c' value has invalid type %T", p2c) + } + } + } + + if value, ok = header.ExtraHeaders[JSONWebTokenHeaderEncryptionAlgorithm]; ok { + switch encv := value.(type) { + case string: + if encv != "" { + enc = encv + + break + } + + return "", "", "", "", fmt.Errorf("jwe header 'enc' value is empty") + default: + return "", "", "", "", fmt.Errorf("jwe header 'enc' value has invalid type %T", encv) + } + } + + if value, ok = header.ExtraHeaders[JSONWebTokenHeaderContentType]; ok { + cty, _ = value.(string) + } + + return header.KeyID, header.Algorithm, enc, cty, nil +} + +// PrivateKey properly describes crypto.PrivateKey. +type PrivateKey interface { + Public() crypto.PublicKey + Equal(x crypto.PrivateKey) bool +} + +const ( + JWKLookupErrorClientNoJWKS uint32 = 1 << iota +) + +type JWKLookupError struct { + Description string + Errors uint32 // bitfield. see JWKLookupError... constants +} + +func (e *JWKLookupError) GetDescription() string { + return e.Description +} + +func (e *JWKLookupError) Error() string { + return fmt.Sprintf("Error occurred retrieving the JSON Web Key. %s", e.Description) +} + +// FindClientPublicJWK given a BaseClient, JWKSFetcherStrategy, and search parameters will return a *jose.JSONWebKey on +// a valid match. The *jose.JSONWebKey is guaranteed to match the alg and use values, and if strict is true it must +// match the kid value as well. +func FindClientPublicJWK(ctx context.Context, client BaseClient, fetcher JWKSFetcherStrategy, kid, alg, use string, strict bool) (key *jose.JSONWebKey, err error) { + var ( + keys *jose.JSONWebKeySet + ) + + if keys = client.GetJSONWebKeys(); keys != nil { + return SearchJWKS(keys, kid, alg, use, strict) + } + + if location := client.GetJSONWebKeysURI(); len(location) > 0 { + if keys, err = fetcher.Resolve(ctx, location, false); err != nil { + return nil, err + } + + if key, err = SearchJWKS(keys, kid, alg, use, strict); err == nil { + return key, nil + } + + if keys, err = fetcher.Resolve(ctx, location, true); err != nil { + return nil, err + } + + return SearchJWKS(keys, kid, alg, use, strict) + } + + return nil, &JWKLookupError{Description: "No JWKs have been registered for the client"} +} + +func SearchJWKS(jwks *jose.JSONWebKeySet, kid, alg, use string, strict bool) (key *jose.JSONWebKey, err error) { + if len(jwks.Keys) == 0 { + return nil, &JWKLookupError{Description: "The retrieved JSON Web Key Set does not contain any key."} + } + + var keys []jose.JSONWebKey + + if kid == "" { + keys = jwks.Keys + } else { + keys = jwks.Key(kid) + } + + if len(keys) == 0 { + return nil, &JWKLookupError{Description: fmt.Sprintf("The JSON Web Token uses signing key with kid '%s' which was not found", kid)} + } + + var matched []jose.JSONWebKey + + for _, k := range keys { + if k.Use != use { + continue + } + + if k.Algorithm != alg { + continue + } + + matched = append(matched, k) + } + + switch len(matched) { + case 1: + return &matched[0], nil + case 0: + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to find JSON web key with kid '%s', use '%s', and alg '%s' in JSON Web Key Set", kid, use, alg)} + default: + if strict { + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to find JSON web key with kid '%s', use '%s', and alg '%s' in JSON Web Key Set", kid, use, alg)} + } + + return &matched[0], nil + } +} + +// NewClientSecretJWKFromClient returns a client secret based JWK from a client. +func NewClientSecretJWKFromClient(ctx context.Context, client BaseClient, kid, alg, enc, use string) (jwk *jose.JSONWebKey, err error) { + var ( + secret []byte + ok bool + ) + + if secret, ok, err = client.GetClientSecretPlainText(); err != nil { + return nil, &JWKLookupError{Description: fmt.Sprintf("The client returned an error while trying to retrieve the plaintext client secret. %s", err.Error())} + } + + if !ok { + return nil, &JWKLookupError{Description: "The client is not configured with a client secret"} + } + + return NewClientSecretJWK(ctx, secret, kid, alg, enc, use) +} + +// NewClientSecretJWK returns a client secret based JWK from a client secret value. +// +// The symmetric encryption key is derived from the client_secret value by using the left-most bits of a truncated +// SHA-2 hash of the octets of the UTF-8 representation of the client_secret. For keys of 256 or fewer bits, SHA-256 +// is used; for keys of 257-384 bits, SHA-384 is used; for keys of 385-512 bits, SHA-512 is used. The hash value MUST +// be truncated retaining the left-most bits to the appropriate bit length for the AES key wrapping or direct +// encryption algorithm used, for instance, truncating the SHA-256 hash to 128 bits for A128KW. If a symmetric key with +// greater than 512 bits is needed, a different method of deriving the key from the client_secret would have to be +// defined by an extension. Symmetric encryption MUST NOT be used by public (non-confidential) Clients because of +// their inability to keep secrets. +func NewClientSecretJWK(ctx context.Context, secret []byte, kid, alg, enc, use string) (jwk *jose.JSONWebKey, err error) { + if len(secret) == 0 { + return nil, &JWKLookupError{Description: "The client is not configured with a client secret that can be used for symmetric algorithms"} + } + + switch use { + case JSONWebTokenUseSignature: + var ( + hasher hash.Hash + ) + + switch jose.SignatureAlgorithm(alg) { + case jose.HS256: + hasher = sha256.New() + case jose.HS384: + hasher = sha512.New384() + case jose.HS512: + hasher = sha512.New() + default: + return nil, &JWKLookupError{Description: fmt.Sprintf("Unsupported algorithm '%s'", alg)} + } + + if _, err = hasher.Write(secret); err != nil { + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to derive key from hashing the client secret. %s", err.Error())} + } + + return &jose.JSONWebKey{ + Key: hasher.Sum(nil), + KeyID: kid, + Algorithm: alg, + Use: use, + }, nil + case JSONWebTokenUseEncryption: + var ( + hasher hash.Hash + bits int + ) + + keyAlg := jose.KeyAlgorithm(alg) + + switch keyAlg { + case jose.A128KW, jose.A128GCMKW, jose.A192KW, jose.A192GCMKW, jose.A256KW, jose.A256GCMKW, jose.PBES2_HS256_A128KW: + hasher = sha256.New() + case jose.PBES2_HS384_A192KW: + hasher = sha512.New384() + case jose.PBES2_HS512_A256KW, jose.DIRECT: + hasher = sha512.New() + default: + return nil, &JWKLookupError{Description: fmt.Sprintf("Unsupported algorithm '%s'", alg)} + } + + switch keyAlg { + case jose.A128KW, jose.A128GCMKW, jose.PBES2_HS256_A128KW: + bits = aes.BlockSize + case jose.A192KW, jose.A192GCMKW, jose.PBES2_HS384_A192KW: + bits = aes.BlockSize * 1.5 + case jose.A256KW, jose.A256GCMKW, jose.PBES2_HS512_A256KW: + bits = aes.BlockSize * 2 + case jose.DIRECT: + switch jose.ContentEncryption(enc) { + case jose.A128CBC_HS256, "": + bits = aes.BlockSize * 2 + case jose.A192CBC_HS384: + bits = aes.BlockSize * 3 + case jose.A256CBC_HS512: + bits = aes.BlockSize * 4 + default: + return nil, &JWKLookupError{Description: fmt.Sprintf("Unsupported content encryption for the direct key algorthm '%s'", enc)} + } + } + + if _, err = hasher.Write(secret); err != nil { + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to derive key from hashing the client secret. %s", err.Error())} + } + + return &jose.JSONWebKey{ + Key: hasher.Sum(nil)[:bits], + KeyID: kid, + Algorithm: alg, + Use: use, + }, nil + default: + return &jose.JSONWebKey{ + Key: secret, + KeyID: kid, + Algorithm: alg, + Use: use, + }, nil + } +} + +// EncodeCompactSigned helps encoding a token using a signature backed compact encoding. +func EncodeCompactSigned(ctx context.Context, claims Claims, headers Mapper, key *jose.JSONWebKey) (tokenString string, signature string, err error) { + token := New() + + if headers == nil { + headers = &Headers{} + } + + token.SetJWS(headers, claims, key.KeyID, jose.SignatureAlgorithm(key.Algorithm)) + + return token.CompactSigned(key) +} + +// EncodeNestedCompactEncrypted helps encoding a token using a signature backed compact encoding, then nests that within +// an encrypted compact encoded JWT. +func EncodeNestedCompactEncrypted(ctx context.Context, claims Claims, headers, headersJWE Mapper, keySig, keyEnc *jose.JSONWebKey, enc jose.ContentEncryption) (tokenString string, signature string, err error) { + token := New() + + if headers == nil { + headers = &Headers{} + } + + if headersJWE == nil { + headersJWE = &Headers{} + } + + token.SetJWS(headers, claims, keySig.KeyID, jose.SignatureAlgorithm(keySig.Algorithm)) + token.SetJWE(headersJWE, keyEnc.KeyID, jose.KeyAlgorithm(keyEnc.Algorithm), enc, jose.NONE) + + return token.CompactEncrypted(keySig, keyEnc) +} + +func getJWTSignature(tokenString string) (signature string, err error) { + switch segments := strings.SplitN(tokenString, ".", 5); len(segments) { + case 5: + return "", errors.WithStack(errors.New("invalid token: the token is probably encrypted")) + case 3: + return segments[2], nil + default: + return "", errors.WithStack(fmt.Errorf("invalid token: the format is unknown")) + } +} + +func assign(a, b map[string]any) map[string]any { + for k, w := range b { + if _, ok := a[k]; ok { + continue + } + a[k] = w + } + return a +} + +func getPublicJWK(jwk *jose.JSONWebKey) jose.JSONWebKey { + if jwk == nil { + return jose.JSONWebKey{} + } + + if _, ok := jwk.Key.([]byte); ok && IsSignedJWTClientSecretAlgStr(jwk.Algorithm) { + return jose.JSONWebKey{ + KeyID: jwk.KeyID, + Key: jwk.Key, + Algorithm: jwk.Algorithm, + Use: jwk.Use, + Certificates: jwk.Certificates, + CertificatesURL: jwk.CertificatesURL, + CertificateThumbprintSHA1: jwk.CertificateThumbprintSHA1, + CertificateThumbprintSHA256: jwk.CertificateThumbprintSHA256, + } + } + + return jwk.Public() +} + +// UnsafeParseSignedAny is a function that will attempt to parse any signed token without any verification process. +// It's unsafe for production and should only be used for tests. +func UnsafeParseSignedAny(tokenString string, dest any) (token *jwt.JSONWebToken, err error) { + if token, err = jwt.ParseSigned(tokenString, SignatureAlgorithmsNone); err != nil { + return nil, err + } + + if err = token.UnsafeClaimsWithoutVerification(dest); err != nil { + return nil, err + } + + return token, nil +} + +func newError(message string, err error, more ...error) error { + var format string + var args []any + if message != "" { + format = "%w: %s" + args = []any{err, message} + } else { + format = "%w" + args = []any{err} + } + + for _, e := range more { + format += ": %w" + args = append(args, e) + } + + err = fmt.Errorf(format, args...) + return err +} + +func toMap(obj any) (result map[string]any) { + result = map[string]any{} + + if obj == nil { + return result + } + + v := reflect.TypeOf(obj) + + reflectValue := reflect.ValueOf(obj) + reflectValue = reflect.Indirect(reflectValue) + + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + for i := 0; i < v.NumField(); i++ { + tag, opts := parseTag(v.Field(i).Tag.Get("json")) + field := reflectValue.Field(i).Interface() + if tag != "" && tag != "-" { + if opts.Contains("omitempty") && isEmptyValue(reflect.ValueOf(field)) { + continue + } + + if v.Field(i).Type.Kind() == reflect.Struct { + result[tag] = toMap(field) + } else { + result[tag] = field + } + } + } + + return result +} + +type tagOptionsJSON string + +func parseTag(tag string) (string, tagOptionsJSON) { + tag, opt, _ := strings.Cut(tag, ",") + return tag, tagOptionsJSON(opt) +} + +func (o tagOptionsJSON) Contains(optionName string) bool { + if len(o) == 0 { + return false + } + + s := string(o) + + for s != "" { + var name string + name, s, _ = strings.Cut(s, ",") + if name == optionName { + return true + } + } + + return false +} + +func isEmptyValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Float32, reflect.Float64, + reflect.Interface, reflect.Pointer: + return v.IsZero() + default: + return false + } +} diff --git a/token/jwt/validate.go b/token/jwt/validate.go new file mode 100644 index 00000000..1797d1cd --- /dev/null +++ b/token/jwt/validate.go @@ -0,0 +1,164 @@ +package jwt + +import ( + "crypto/subtle" + "time" +) + +type ClaimValidationOption func(opts *ClaimValidationOptions) + +type ClaimValidationOptions struct { + timef func() time.Time + iss string + aud []string + audAll []string + sub string + expRequired bool + iatRequired bool + nbfRequired bool +} + +func ValidateTimeFunc(timef func() time.Time) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.timef = timef + } +} + +func ValidateIssuer(iss string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.iss = iss + } +} + +func ValidateAudienceAny(aud ...string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.aud = aud + } +} + +func ValidateAudienceAll(aud ...string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.audAll = aud + } +} + +func ValidateSubject(sub string) ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.sub = sub + } +} + +func ValidateRequireExpiresAt() ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.expRequired = true + } +} + +func ValidateRequireIssuedAt() ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.iatRequired = true + } +} + +func ValidateRequireNotBefore() ClaimValidationOption { + return func(opts *ClaimValidationOptions) { + opts.nbfRequired = true + } +} + +func verifyAud(aud []string, cmp string, required bool) bool { + if len(aud) == 0 { + return !required + } + + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) == 1 { + return true + } + } + + return false +} + +func verifyAudAny(aud []string, cmp []string, required bool) bool { + if len(aud) == 0 { + return !required + } + + for _, c := range cmp { + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(c)) == 1 { + return true + } + } + } + + return false +} + +func verifyAudAll(aud []string, cmp []string, required bool) bool { + if len(aud) == 0 { + return !required + } + +outer: + for _, c := range cmp { + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(c)) == 1 { + continue outer + } + } + + return false + } + + return true +} + +// validInt64Future ensures the given value is in the future. +func validInt64Future(value, now int64, required bool) bool { + if value == 0 { + return !required + } + + return now <= value +} + +// validInt64Past ensures the given value is in the past or the current value. +func validInt64Past(value, now int64, required bool) bool { + if value == 0 { + return !required + } + + return now >= value +} + +func validString(value, cmp string, required bool) bool { + if value == "" { + return !required + } + + return subtle.ConstantTimeCompare([]byte(value), []byte(cmp)) == 1 +} + +type validDateFunc func(value, now int64, required bool) bool + +func validDate(valid validDateFunc, now int64, required bool, date *NumericDate, err error) bool { + if err != nil || valid == nil { + return false + } + + if date == nil { + if required { + return false + } + + return true + } + + if valid(date.Int64(), now, required) { + return true + } + + return false +} diff --git a/token/jwt/validation_error.go b/token/jwt/validation_error.go index 05a32432..ce583f76 100644 --- a/token/jwt/validation_error.go +++ b/token/jwt/validation_error.go @@ -5,25 +5,31 @@ package jwt // Validation provides a backwards compatible error definition // from `jwt-go` to `go-jose`. -// The sourcecode was taken from https://github.com/dgrijalva/jwt-go/blob/master/errors.go -// -// > The errors that might occur when parsing and validating a token const ( - ValidationErrorMalformed uint32 = 1 << iota // Token is malformed - ValidationErrorUnverifiable // Token could not be verified because of signing problems - ValidationErrorSignatureInvalid // Signature validation failed - - // Standard Claim validation errors - ValidationErrorAudience // AUD validation failed - ValidationErrorExpired // EXP validation failed - ValidationErrorIssuedAt // IAT validation failed - ValidationErrorIssuer // ISS validation failed - ValidationErrorNotValidYet // NBF validation failed - ValidationErrorId // JTI validation failed - ValidationErrorClaimsInvalid // Generic claims validation error + ValidationErrorMalformed uint32 = 1 << iota // Token is malformed + ValidationErrorMalformedNotCompactSerialized // Token is malformed specifically it does not have the compact serialized format. + ValidationErrorUnverifiable // Token could not be verified because of signing problems + ValidationErrorSignatureInvalid // Signature validation failed. + ValidationErrorHeaderKeyIDInvalid // Header KID invalid error. + ValidationErrorHeaderAlgorithmInvalid // Header ALG invalid error. + ValidationErrorHeaderTypeInvalid // Header TYP invalid error. + ValidationErrorHeaderEncryptionTypeInvalid // Header TYP invalid error (JWE). + ValidationErrorHeaderContentTypeInvalid // Header TYP invalid error (JWE). + ValidationErrorHeaderContentTypeInvalidMismatch // Header TYP invalid error (JWE). + ValidationErrorHeaderEncryptionKeyIDInvalid // Header KID invalid error (JWE). + ValidationErrorHeaderKeyAlgorithmInvalid // Header ALG invalid error (JWE). + ValidationErrorHeaderContentEncryptionInvalid // Header ENC invalid error (JWE). + ValidationErrorId // Claim JTI validation failed. + ValidationErrorAudience // Claim AUD validation failed. + ValidationErrorExpired // Claim EXP validation failed. + ValidationErrorIssuedAt // Claim IAT validation failed. + ValidationErrorNotValidYet // Claim NBF validation failed. + ValidationErrorIssuer // Claim ISS validation failed. + ValidationErrorSubject // Claim SUB validation failed. + ValidationErrorClaimsInvalid // Generic claims validation error. ) -// The error from Parse if token is not valid +// The ValidationError is an error implementation from Parse if token is not valid. type ValidationError struct { Inner error // stores the error returned by external dependencies, i.e.: KeyFunc Errors uint32 // bitfield. see ValidationError... constants diff --git a/token/jwt/variables.go b/token/jwt/variables.go new file mode 100644 index 00000000..d5421afc --- /dev/null +++ b/token/jwt/variables.go @@ -0,0 +1,9 @@ +package jwt + +import "time" + +var ( + MarshalSingleStringAsArray = true + TimePrecision = time.Second + TimeFunc = time.Now +)