Skip to content

Commit

Permalink
feat(jwt): support jwe
Browse files Browse the repository at this point in the history
This adds support for RFC7516 JSON Web Encryption (JWE). This adds a layer of privacy on top of tokens generated and asserted making their contents irrelevant. This is supported for both generated tokens and decoded tokens for elements like client assertions. This implementation uses the nested format which is generally the supported and recommended format.
  • Loading branch information
james-d-elliott committed Sep 4, 2024
1 parent 1d32ff6 commit bb5471d
Show file tree
Hide file tree
Showing 62 changed files with 2,846 additions and 905 deletions.
4 changes: 2 additions & 2 deletions access_error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestWriteAccessError(t *testing.T) {
rw.EXPECT().WriteHeader(http.StatusBadRequest)
rw.EXPECT().Write(gomock.Any())

provider.WriteAccessError(context.Background(), rw, nil, ErrInvalidRequest)
provider.WriteAccessError(context.TODO(), rw, nil, ErrInvalidRequest)
}

func TestWriteAccessError_RFC6749(t *testing.T) {
Expand Down Expand Up @@ -62,7 +62,7 @@ func TestWriteAccessError_RFC6749(t *testing.T) {
config.UseLegacyErrorFormat = c.includeExtraFields

rw := httptest.NewRecorder()
provider.WriteAccessError(context.Background(), rw, nil, c.err)
provider.WriteAccessError(context.TODO(), rw, nil, c.err)

var params struct {
Error string `json:"error"` // specified by RFC, required
Expand Down
2 changes: 1 addition & 1 deletion access_write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestWriteAccessResponse(t *testing.T) {
rw.EXPECT().Write(gomock.Any())
resp.EXPECT().ToMap().Return(map[string]any{})

provider.WriteAccessResponse(context.Background(), rw, ar, resp)
provider.WriteAccessResponse(context.TODO(), rw, ar, resp)
assert.Equal(t, consts.ContentTypeApplicationJSON, header.Get(consts.HeaderContentType))
assert.Equal(t, consts.CacheControlNoStore, header.Get(consts.HeaderCacheControl))
assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma))
Expand Down
73 changes: 21 additions & 52 deletions authorize_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"net/http"
"strings"

"github.com/go-jose/go-jose/v4"
"github.com/pkg/errors"

"authelia.com/provider/oauth2/i18n"
Expand Down Expand Up @@ -74,10 +73,11 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co
}

var (
alg string
algAny, algNone bool
)

switch alg := client.GetRequestObjectSigningAlg(); alg {
switch alg = client.GetRequestObjectSigningAlg(); alg {
case consts.JSONWebTokenAlgNone:
algNone = true
case "":
Expand Down Expand Up @@ -123,65 +123,34 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequestObject(ctx co
assertion = request.Form.Get(consts.FormParameterRequest)
}

token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (key any, err error) {
// request_object_signing_alg - OPTIONAL.
// JWS [JWS] alg algorithm [JWA] that MUST be used for signing Request Objects sent to the OP. All Request Objects from this Client MUST be rejected,
// if not signed with this algorithm. Request Objects are described in Section 6.1 of OpenID Connect Core 1.0 [OpenID.Core]. This algorithm MUST
// be used both when the Request Object is passed by value (using the request parameter) and when it is passed by reference (using the request_uri parameter).
// Servers SHOULD support RS256. The value none MAY be used. The default, if omitted, is that any algorithm supported by the OP and the RP MAY be used.
if !algAny && client.GetRequestObjectSigningAlg() != fmt.Sprintf("%s", t.Header[consts.JSONWebTokenHeaderAlgorithm]) {
return nil, errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' expects request objects to be signed with the '%s' algorithm but the request object was signed with the '%s' algorithm.", request.GetClient().GetID(), client.GetRequestObjectSigningAlg(), t.Header[consts.JSONWebTokenHeaderAlgorithm]))
}

if t.Method == jwt.SigningMethodNone {
algNone = true
strategy := f.Config.GetJWTStrategy(ctx)

return jwt.UnsafeAllowNoneSignatureType, nil
} else if algNone {
return nil, errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' expects request objects to be signed with the '%s' algorithm but the request object was signed with the '%s' algorithm.", request.GetClient().GetID(), client.GetRequestObjectSigningAlg(), t.Header[consts.JSONWebTokenHeaderAlgorithm]))
token, err := strategy.Decode(ctx, assertion, jwt.WithSigAlgorithm(jwt.SignatureAlgorithmsNone...), jwt.WithJARClient(client))
if err != nil {
var e *jwt.ValidationError
if errors.As(err, &e) {
return wrapSigningKeyFailure(ErrInvalidRequestObject.WithHintf("The OAuth 2.0 client with id '%s' could not validate the request object.", client.GetID()), err)
} else {
return errorsx.WithStack(ErrInvalidRequestObject.WithHintf("The OAuth 2.0 client with id '%s' could not validate the request object.", client.GetID()).WithDebugError(err))
}
}

switch t.Method {
case jose.RS256, jose.RS384, jose.RS512:
if key, err = f.findClientPublicJWK(ctx, client, t, true); err != nil {
return nil, wrapSigningKeyFailure(
ErrInvalidRequestObject.WithHint("Unable to retrieve RSA signing key from OAuth 2.0 Client."), err)
}

return key, nil
case jose.ES256, jose.ES384, jose.ES512:
if key, err = f.findClientPublicJWK(ctx, client, t, false); err != nil {
return nil, wrapSigningKeyFailure(
ErrInvalidRequestObject.WithHint("Unable to retrieve ECDSA signing key from OAuth 2.0 Client."), err)
}

return key, nil
case jose.PS256, jose.PS384, jose.PS512:
if key, err = f.findClientPublicJWK(ctx, client, t, true); err != nil {
return nil, wrapSigningKeyFailure(
ErrInvalidRequestObject.WithHint("Unable to retrieve RSA signing key from OAuth 2.0 Client."), err)
}
if algAny {
if token.SignatureAlgorithm == "none" {

return key, nil
default:
return nil, errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object that uses the unsupported signing algorithm '%s'.", request.GetClient().GetID(), t.Header[consts.JSONWebTokenHeaderAlgorithm]))
}
})

if err != nil {
// Do not re-process already enhanced errors
var e *jwt.ValidationError
if errors.As(err, &e) {
if e.Inner != nil {
return e.Inner
}
if kid := client.GetRequestObjectSigningKeyID(); kid != "" && kid != token.KeyID {

return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object which failed to validate with error: %+v.", request.GetClient().GetID(), err).WithWrap(err))
}
} else if string(token.SignatureAlgorithm) != alg {

return err
} else if err = token.Claims.Valid(); err != nil {
return errorsx.WithStack(ErrInvalidRequestObject.WithHintf(hintRequestObjectValidate, hintRequestObjectPrefix(openid)).WithDebugf("The OAuth 2.0 client with id '%s' provided a request object which could not be validated because its claims could not be validated with error: %+v.", request.GetClient().GetID(), err).WithWrap(err))
}

if token.SignatureAlgorithm == "none" && !algNone {
print("")
} else if !algAny && string(token.SignatureAlgorithm) != client.GetRequestObjectSigningAlg() {
print("")
}

claims := token.Claims
Expand Down
22 changes: 15 additions & 7 deletions authorize_request_handler_oidc_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T)
jwks := &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
KeyID: "kid-foo",
Use: "sig",
Key: &key.PublicKey,
KeyID: "kid-foo",
Use: "sig",
Algorithm: string(jose.RS256),
Key: &key.PublicKey,
},
},
}
Expand Down Expand Up @@ -371,7 +372,14 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequestObject(t *testing.T)
},
}

provider := &Fosite{Config: &Config{JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), IDTokenIssuer: "https://auth.example.com"}}
config := &Config{JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), IDTokenIssuer: "https://auth.example.com"}

strategy := &jwt.DefaultStrategy{
Config: config,
Issuer: jwt.MustGenDefaultIssuer(),
}

provider := &Fosite{Config: &Config{JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), IDTokenIssuer: "https://auth.example.com", JWTStrategy: strategy}}

err = provider.authorizeRequestParametersFromOpenIDConnectRequestObject(context.Background(), r, tc.par)
if tc.err != nil {
Expand Down Expand Up @@ -400,21 +408,21 @@ func mustGenerateAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateK
if kid != "" {
token.Header[consts.JSONWebTokenHeaderKeyIdentifier] = kid
}
tokenString, err := token.SignedString(key)
tokenString, err := token.CompactSignedString(key)
require.NoError(t, err)
return tokenString
}

func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jose.HS256, claims)
tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd"))
tokenString, err := token.CompactSignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd"))
require.NoError(t, err)
return tokenString
}

func mustGenerateNoneAssertion(t *testing.T, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
tokenString, err := token.CompactSignedString(jwt.UnsafeAllowNoneSignatureType)
require.NoError(t, err)
return tokenString
}
64 changes: 50 additions & 14 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package oauth2

import (
"context"
"fmt"
"time"

"github.com/go-jose/go-jose/v4"
Expand All @@ -17,8 +18,12 @@ type Client interface {
// GetID returns the client ID.
GetID() (id string)

// GetClientSecret returns the ClientSecret.
GetClientSecret() (secret ClientSecret)

// GetClientSecretPlainText returns the ClientSecret as plaintext if available.
GetClientSecretPlainText() (secret []byte, err error)

// GetRedirectURIs returns the client's allowed redirect URIs.
GetRedirectURIs() []string

Expand Down Expand Up @@ -230,6 +235,19 @@ type AuthenticationMethodClient interface {
// methods.
GetRevocationEndpointAuthSigningAlg() (alg string)

// GetPushedAuthorizationRequestEndpointAuthMethod is equivalent to the
// 'pushed_authorize_request_endpoint_auth_method' client metadata value which determines the requested Client
// Authentication method for the Pushed Authorization Request Endpoint. The options are client_secret_post,
// client_secret_basic, client_secret_jwt, and private_key_jwt.
GetPushedAuthorizationRequestEndpointAuthMethod() (method string)

// GetPushedAuthorizationRequestEndpointAuthSigningAlg is equivalent to the
// 'pushed_authorization_request_endpoint_auth_signing_alg' client metadata value which determines the JWS [JWS] alg
// algorithm [JWA] that MUST be used for signing the JWT [JWT] used to authenticate the
// Client at the Pushed Authorization Request Endpoint for the private_key_jwt and client_secret_jwt authentication
// methods.
GetPushedAuthorizationRequestEndpointAuthSigningAlg() (alg string)

JSONWebKeysClient
}

Expand Down Expand Up @@ -414,20 +432,22 @@ type DefaultClient struct {

type DefaultJARClient struct {
*DefaultClient
JSONWebKeysURI string `json:"jwks_uri"`
JSONWebKeys *jose.JSONWebKeySet `json:"jwks"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
IntrospectionEndpointAuthMethod string `json:"introspection_endpoint_auth_method"`
RevocationEndpointAuthMethod string `json:"revocation_endpoint_auth_method"`
RequestURIs []string `json:"request_uris"`
RequestObjectSigningKeyID string `json:"request_object_signing_kid"`
RequestObjectSigningAlg string `json:"request_object_signing_alg"`
RequestObjectEncryptionKeyID string `json:"request_object_encryption_kid"`
RequestObjectEncryptionAlg string `json:"request_object_encryption_alg"`
RequestObjectEncryptionEnc string `json:"request_object_encryption_enc"`
TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"`
IntrospectionEndpointAuthSigningAlg string `json:"introspection_endpoint_auth_signing_alg"`
RevocationEndpointAuthSigningAlg string `json:"revocation_endpoint_auth_signing_alg"`
JSONWebKeysURI string `json:"jwks_uri"`
JSONWebKeys *jose.JSONWebKeySet `json:"jwks"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
IntrospectionEndpointAuthMethod string `json:"introspection_endpoint_auth_method"`
RevocationEndpointAuthMethod string `json:"revocation_endpoint_auth_method"`
PushedAuthorizationRequestEndpointAuthMethod string `json:"pushed_authorization_request_endpoint_auth_method"`
RequestURIs []string `json:"request_uris"`
RequestObjectSigningKeyID string `json:"request_object_signing_kid"`
RequestObjectSigningAlg string `json:"request_object_signing_alg"`
RequestObjectEncryptionKeyID string `json:"request_object_encryption_kid"`
RequestObjectEncryptionAlg string `json:"request_object_encryption_alg"`
RequestObjectEncryptionEnc string `json:"request_object_encryption_enc"`
TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"`
IntrospectionEndpointAuthSigningAlg string `json:"introspection_endpoint_auth_signing_alg"`
RevocationEndpointAuthSigningAlg string `json:"revocation_endpoint_auth_signing_alg"`
PushedAuthorizationRequestEndpointAuthSigningAlg string `json:"pushed_authorization_request_endpoint_auth_signing_alg"`
}

type DefaultResponseModeClient struct {
Expand Down Expand Up @@ -455,6 +475,14 @@ func (c *DefaultClient) GetClientSecret() (secret ClientSecret) {
return c.ClientSecret
}

func (c *DefaultClient) GetClientSecretPlainText() (secret []byte, err error) {
if c.ClientSecret == nil || !c.ClientSecret.Valid() {
return nil, fmt.Errorf("this secret doesn't support plaintext")
}

return c.ClientSecret.GetPlainTextValue()
}

func (c *DefaultClient) GetRotatedClientSecrets() (secrets []ClientSecret) {
return c.RotatedClientSecrets
}
Expand Down Expand Up @@ -513,6 +541,10 @@ func (c *DefaultJARClient) GetRevocationEndpointAuthSigningAlg() string {
return c.RevocationEndpointAuthSigningAlg
}

func (c *DefaultJARClient) GetPushedAuthorizationRequestEndpointAuthSigningAlg() (alg string) {
return c.PushedAuthorizationRequestEndpointAuthSigningAlg
}

func (c *DefaultJARClient) GetRequestObjectSigningKeyID() string {
return c.RequestObjectSigningKeyID
}
Expand Down Expand Up @@ -545,6 +577,10 @@ func (c *DefaultJARClient) GetRevocationEndpointAuthMethod() string {
return c.RevocationEndpointAuthMethod
}

func (c *DefaultJARClient) GetPushedAuthorizationRequestEndpointAuthMethod() string {
return c.PushedAuthorizationRequestEndpointAuthMethod
}

func (c *DefaultJARClient) GetRequestURIs() []string {
return c.RequestURIs
}
Expand Down
16 changes: 4 additions & 12 deletions client_authentication_jwks_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,13 @@ import (
"github.com/go-jose/go-jose/v4"
"github.com/hashicorp/go-retryablehttp"

"authelia.com/provider/oauth2/token/jwt"
"authelia.com/provider/oauth2/x/errorsx"
)

const defaultJWKSFetcherStrategyCachePrefix = "authelia.com/provider/oauth2.DefaultJWKSFetcherStrategy:"

// JWKSFetcherStrategy is a strategy which pulls (optionally caches) JSON Web Key Sets from a location,
// typically a client's jwks_uri.
type JWKSFetcherStrategy interface {
// Resolve returns the JSON Web Key Set, or an error if something went wrong. The forceRefresh, if true, forces
// the strategy to fetch the key from the remote. If forceRefresh is false, the strategy may use a caching strategy
// to fetch the key.
Resolve(ctx context.Context, location string, ignoreCache bool) (*jose.JSONWebKeySet, error)
}

// DefaultJWKSFetcherStrategy is a default implementation of the JWKSFetcherStrategy interface.
// DefaultJWKSFetcherStrategy is a default implementation of the jwt.JWKSFetcherStrategy interface.
type DefaultJWKSFetcherStrategy struct {
client *retryablehttp.Client
cache *ristretto.Cache
Expand All @@ -36,7 +28,7 @@ type DefaultJWKSFetcherStrategy struct {
}

// NewDefaultJWKSFetcherStrategy returns a new instance of the DefaultJWKSFetcherStrategy.
func NewDefaultJWKSFetcherStrategy(opts ...func(*DefaultJWKSFetcherStrategy)) JWKSFetcherStrategy {
func NewDefaultJWKSFetcherStrategy(opts ...func(*DefaultJWKSFetcherStrategy)) jwt.JWKSFetcherStrategy {
dc, err := ristretto.NewCache(&ristretto.Config{
NumCounters: 10000 * 10,
MaxCost: 10000,
Expand Down Expand Up @@ -100,7 +92,7 @@ func (s *DefaultJWKSFetcherStrategy) Resolve(ctx context.Context, location strin
if !ok || ignoreCache {
req, err := retryablehttp.NewRequest(http.MethodGet, location, nil)
if err != nil {
return nil, errorsx.WithStack(ErrServerError.WithHintf("Unable to create HTTP 'GET' request to fetch JSON Web Keys from location '%s'.", location).WithWrap(err).WithDebugError(err))
return nil, errorsx.WithStack(ErrServerError.WithHintf("Unable to create HTTP 'GET' request to fetch JSON Web Keys from location '%s'.", location).WithWrap(err).WithDebugError(err))
}

hc := s.client
Expand Down
8 changes: 4 additions & 4 deletions client_authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -792,31 +792,31 @@ func TestAuthenticateClientTwice(t *testing.T) {
func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jose.RS256, claims)
token.Header["kid"] = kid
tokenString, err := token.SignedString(key)
tokenString, err := token.CompactSignedString(key)
require.NoError(t, err)
return tokenString
}

func mustGenerateECDSAAssertion(t *testing.T, claims jwt.MapClaims, key *ecdsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jose.ES256, claims)
token.Header["kid"] = kid
tokenString, err := token.SignedString(key)
tokenString, err := token.CompactSignedString(key)
require.NoError(t, err)
return tokenString
}

//nolint:unparam
func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jose.HS256, claims)
tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd"))
tokenString, err := token.CompactSignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd"))
require.NoError(t, err)
return tokenString
}

//nolint:unparam
func mustGenerateNoneAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
tokenString, err := token.CompactSignedString(jwt.UnsafeAllowNoneSignatureType)
require.NoError(t, err)
return tokenString
}
Expand Down
Loading

0 comments on commit bb5471d

Please sign in to comment.