From 6540167991baeec884c8af87b5095f595addd13f Mon Sep 17 00:00:00 2001 From: James Elliott Date: Thu, 3 Oct 2024 22:54:03 +1000 Subject: [PATCH] feat: claims interface --- token/jwt/claims_id_token.go | 130 +++++++++++++++++------------- token/jwt/claims_id_token_test.go | 2 + token/jwt/util.go | 21 ++++- 3 files changed, 95 insertions(+), 58 deletions(-) diff --git a/token/jwt/claims_id_token.go b/token/jwt/claims_id_token.go index 4dec51a2..d070dbd8 100644 --- a/token/jwt/claims_id_token.go +++ b/token/jwt/claims_id_token.go @@ -27,12 +27,12 @@ type IDTokenClaims struct { IssuedAt *NumericDate `json:"iat"` RequestedAt *NumericDate `json:"rat"` AuthTime *NumericDate `json:"auth_time"` - AccessTokenHash string `json:"at_hash"` AuthenticationContextClassReference string `json:"acr"` AuthenticationMethodsReferences []string `json:"amr"` + AccessTokenHash string `json:"at_hash"` CodeHash string `json:"c_hash"` StateHash string `json:"s_hash"` - Extra map[string]any `json:"ext"` + Extra map[string]any `json:"ext,omitempty"` } func (c *IDTokenClaims) GetExpirationTime() (exp *NumericDate, err error) { @@ -196,32 +196,7 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { case ClaimSubject: c.Subject, ok = value.(string) case ClaimAudience: - switch aud := value.(type) { - case nil: - ok = true - case string: - ok = true - - c.Audience = []string{aud} - case []string: - ok = true - - c.Audience = aud - case []any: - ok = true - - loop: - for _, av := range aud { - switch a := av.(type) { - case string: - c.Audience = append(c.Audience, a) - default: - ok = false - - break loop - } - } - } + c.Audience, ok = toStringSlice(value) case ClaimNonce: c.Nonce, ok = value.(string) case ClaimExpirationTime: @@ -240,12 +215,16 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { if c.AuthTime, err = toNumericDate(value); err == nil { ok = true } + case ClaimAuthenticationContextClassReference: + c.AuthenticationContextClassReference, ok = value.(string) + case ClaimAuthenticationMethodsReference: + c.AuthenticationMethodsReferences, ok = toStringSlice(value) + case ClaimAccessTokenHash: + c.AccessTokenHash, ok = value.(string) case ClaimCodeHash: c.CodeHash, ok = value.(string) case ClaimStateHash: c.StateHash, ok = value.(string) - case ClaimAuthenticationContextClassReference: - c.AuthenticationContextClassReference, ok = value.(string) default: if c.Extra == nil { c.Extra = make(map[string]any) @@ -268,10 +247,10 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { func (c *IDTokenClaims) ToMap() map[string]any { var ret = Copy(c.Extra) - if c.Subject != "" { - ret[ClaimSubject] = c.Subject + if c.JTI != "" { + ret[consts.ClaimJWTID] = c.JTI } else { - delete(ret, ClaimSubject) + ret[consts.ClaimJWTID] = uuid.New().String() } if c.Issuer != "" { @@ -280,10 +259,10 @@ func (c *IDTokenClaims) ToMap() map[string]any { delete(ret, consts.ClaimIssuer) } - if c.JTI != "" { - ret[consts.ClaimJWTID] = c.JTI + if c.Subject != "" { + ret[ClaimSubject] = c.Subject } else { - ret[consts.ClaimJWTID] = uuid.New().String() + delete(ret, ClaimSubject) } if len(c.Audience) > 0 { @@ -292,40 +271,28 @@ func (c *IDTokenClaims) ToMap() map[string]any { ret[consts.ClaimAudience] = []string{} } - if c.IssuedAt != nil { - ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() - } else { - delete(ret, consts.ClaimIssuedAt) - } - - if c.ExpirationTime != nil { - ret[consts.ClaimExpirationTime] = c.ExpirationTime.Unix() - } else { - delete(ret, consts.ClaimExpirationTime) - } - if len(c.Nonce) > 0 { ret[consts.ClaimNonce] = c.Nonce } else { delete(ret, consts.ClaimNonce) } - if len(c.AccessTokenHash) > 0 { - ret[consts.ClaimAccessTokenHash] = c.AccessTokenHash + if c.ExpirationTime != nil { + ret[consts.ClaimExpirationTime] = c.ExpirationTime.Unix() } else { - delete(ret, consts.ClaimAccessTokenHash) + delete(ret, consts.ClaimExpirationTime) } - if len(c.CodeHash) > 0 { - ret[consts.ClaimCodeHash] = c.CodeHash + if c.IssuedAt != nil { + ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() } else { - delete(ret, consts.ClaimCodeHash) + delete(ret, consts.ClaimIssuedAt) } - if len(c.StateHash) > 0 { - ret[consts.ClaimStateHash] = c.StateHash + if c.RequestedAt != nil { + ret[consts.ClaimRequestedAt] = c.RequestedAt.Unix() } else { - delete(ret, consts.ClaimStateHash) + delete(ret, consts.ClaimRequestedAt) } if c.AuthTime != nil { @@ -346,6 +313,24 @@ func (c *IDTokenClaims) ToMap() map[string]any { delete(ret, consts.ClaimAuthenticationMethodsReference) } + if len(c.AccessTokenHash) > 0 { + ret[consts.ClaimAccessTokenHash] = c.AccessTokenHash + } else { + delete(ret, consts.ClaimAccessTokenHash) + } + + if len(c.CodeHash) > 0 { + ret[consts.ClaimCodeHash] = c.CodeHash + } else { + delete(ret, consts.ClaimCodeHash) + } + + if len(c.StateHash) > 0 { + ret[consts.ClaimStateHash] = c.StateHash + } else { + delete(ret, consts.ClaimStateHash) + } + return ret } @@ -381,6 +366,37 @@ func (c IDTokenClaims) toNumericDate(key string) (date *NumericDate, err error) return toNumericDate(v) } +func toStringSlice(value any) (values []string, ok bool) { + switch t := value.(type) { + case nil: + ok = true + case string: + ok = true + + values = []string{t} + case []string: + ok = true + + values = t + case []any: + ok = true + + loop: + for _, tv := range t { + switch vv := tv.(type) { + case string: + values = append(values, vv) + default: + ok = false + + break loop + } + } + } + + return values, ok +} + var ( _ Claims = (*IDTokenClaims)(nil) ) diff --git a/token/jwt/claims_id_token_test.go b/token/jwt/claims_id_token_test.go index bc35493f..8c22119c 100644 --- a/token/jwt/claims_id_token_test.go +++ b/token/jwt/claims_id_token_test.go @@ -49,6 +49,7 @@ func TestIDTokenClaimsToMap(t *testing.T) { ClaimIssuer: idTokenClaims.Issuer, ClaimAudience: idTokenClaims.Audience, ClaimExpirationTime: idTokenClaims.ExpirationTime.Unix(), + ClaimRequestedAt: idTokenClaims.RequestedAt.Unix(), "foo": idTokenClaims.Extra["foo"], "baz": idTokenClaims.Extra["baz"], ClaimAccessTokenHash: idTokenClaims.AccessTokenHash, @@ -67,6 +68,7 @@ func TestIDTokenClaimsToMap(t *testing.T) { consts.ClaimIssuer: idTokenClaims.Issuer, consts.ClaimAudience: idTokenClaims.Audience, consts.ClaimExpirationTime: idTokenClaims.ExpirationTime.Unix(), + consts.ClaimRequestedAt: idTokenClaims.RequestedAt.Unix(), "foo": idTokenClaims.Extra["foo"], "baz": idTokenClaims.Extra["baz"], consts.ClaimAccessTokenHash: idTokenClaims.AccessTokenHash, diff --git a/token/jwt/util.go b/token/jwt/util.go index f97ec97c..2f299772 100644 --- a/token/jwt/util.go +++ b/token/jwt/util.go @@ -276,8 +276,27 @@ func NewClientSecretJWK(ctx context.Context, secret []byte, kid, alg, enc, use s switch use { case JSONWebTokenUseSignature: + var ( + hasher hash.Hash + ) + + switch jose.SignatureAlgorithm(alg) { + case jose.HS256: + hasher = sha256.New() + case jose.HS384: + hasher = sha512.New384() + case jose.HS512: + hasher = sha512.New() + default: + return nil, &JWKLookupError{Description: fmt.Sprintf("Unsupported algorithm '%s'", alg)} + } + + if _, err = hasher.Write(secret); err != nil { + return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to derive key from hashing the client secret. %s", err.Error())} + } + return &jose.JSONWebKey{ - Key: secret, + Key: hasher.Sum(nil), KeyID: kid, Algorithm: alg, Use: use,