Skip to content

Commit

Permalink
feat: claims interface
Browse files Browse the repository at this point in the history
  • Loading branch information
james-d-elliott committed Oct 3, 2024
1 parent 7fac454 commit 6540167
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 58 deletions.
130 changes: 73 additions & 57 deletions token/jwt/claims_id_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ type IDTokenClaims struct {
IssuedAt *NumericDate `json:"iat"`
RequestedAt *NumericDate `json:"rat"`
AuthTime *NumericDate `json:"auth_time"`
AccessTokenHash string `json:"at_hash"`
AuthenticationContextClassReference string `json:"acr"`
AuthenticationMethodsReferences []string `json:"amr"`
AccessTokenHash string `json:"at_hash"`
CodeHash string `json:"c_hash"`
StateHash string `json:"s_hash"`
Extra map[string]any `json:"ext"`
Extra map[string]any `json:"ext,omitempty"`
}

func (c *IDTokenClaims) GetExpirationTime() (exp *NumericDate, err error) {
Expand Down Expand Up @@ -196,32 +196,7 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error {
case ClaimSubject:
c.Subject, ok = value.(string)
case ClaimAudience:
switch aud := value.(type) {
case nil:
ok = true
case string:
ok = true

c.Audience = []string{aud}
case []string:
ok = true

c.Audience = aud
case []any:
ok = true

loop:
for _, av := range aud {
switch a := av.(type) {
case string:
c.Audience = append(c.Audience, a)
default:
ok = false

break loop
}
}
}
c.Audience, ok = toStringSlice(value)
case ClaimNonce:
c.Nonce, ok = value.(string)
case ClaimExpirationTime:
Expand All @@ -240,12 +215,16 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error {
if c.AuthTime, err = toNumericDate(value); err == nil {
ok = true
}
case ClaimAuthenticationContextClassReference:
c.AuthenticationContextClassReference, ok = value.(string)
case ClaimAuthenticationMethodsReference:
c.AuthenticationMethodsReferences, ok = toStringSlice(value)
case ClaimAccessTokenHash:
c.AccessTokenHash, ok = value.(string)
case ClaimCodeHash:
c.CodeHash, ok = value.(string)
case ClaimStateHash:
c.StateHash, ok = value.(string)
case ClaimAuthenticationContextClassReference:
c.AuthenticationContextClassReference, ok = value.(string)
default:
if c.Extra == nil {
c.Extra = make(map[string]any)
Expand All @@ -268,10 +247,10 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error {
func (c *IDTokenClaims) ToMap() map[string]any {
var ret = Copy(c.Extra)

if c.Subject != "" {
ret[ClaimSubject] = c.Subject
if c.JTI != "" {
ret[consts.ClaimJWTID] = c.JTI
} else {
delete(ret, ClaimSubject)
ret[consts.ClaimJWTID] = uuid.New().String()
}

if c.Issuer != "" {
Expand All @@ -280,10 +259,10 @@ func (c *IDTokenClaims) ToMap() map[string]any {
delete(ret, consts.ClaimIssuer)
}

if c.JTI != "" {
ret[consts.ClaimJWTID] = c.JTI
if c.Subject != "" {
ret[ClaimSubject] = c.Subject
} else {
ret[consts.ClaimJWTID] = uuid.New().String()
delete(ret, ClaimSubject)
}

if len(c.Audience) > 0 {
Expand All @@ -292,40 +271,28 @@ func (c *IDTokenClaims) ToMap() map[string]any {
ret[consts.ClaimAudience] = []string{}
}

if c.IssuedAt != nil {
ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix()
} else {
delete(ret, consts.ClaimIssuedAt)
}

if c.ExpirationTime != nil {
ret[consts.ClaimExpirationTime] = c.ExpirationTime.Unix()
} else {
delete(ret, consts.ClaimExpirationTime)
}

if len(c.Nonce) > 0 {
ret[consts.ClaimNonce] = c.Nonce
} else {
delete(ret, consts.ClaimNonce)
}

if len(c.AccessTokenHash) > 0 {
ret[consts.ClaimAccessTokenHash] = c.AccessTokenHash
if c.ExpirationTime != nil {
ret[consts.ClaimExpirationTime] = c.ExpirationTime.Unix()
} else {
delete(ret, consts.ClaimAccessTokenHash)
delete(ret, consts.ClaimExpirationTime)
}

if len(c.CodeHash) > 0 {
ret[consts.ClaimCodeHash] = c.CodeHash
if c.IssuedAt != nil {
ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix()
} else {
delete(ret, consts.ClaimCodeHash)
delete(ret, consts.ClaimIssuedAt)
}

if len(c.StateHash) > 0 {
ret[consts.ClaimStateHash] = c.StateHash
if c.RequestedAt != nil {
ret[consts.ClaimRequestedAt] = c.RequestedAt.Unix()
} else {
delete(ret, consts.ClaimStateHash)
delete(ret, consts.ClaimRequestedAt)
}

if c.AuthTime != nil {
Expand All @@ -346,6 +313,24 @@ func (c *IDTokenClaims) ToMap() map[string]any {
delete(ret, consts.ClaimAuthenticationMethodsReference)
}

if len(c.AccessTokenHash) > 0 {
ret[consts.ClaimAccessTokenHash] = c.AccessTokenHash
} else {
delete(ret, consts.ClaimAccessTokenHash)
}

if len(c.CodeHash) > 0 {
ret[consts.ClaimCodeHash] = c.CodeHash
} else {
delete(ret, consts.ClaimCodeHash)
}

if len(c.StateHash) > 0 {
ret[consts.ClaimStateHash] = c.StateHash
} else {
delete(ret, consts.ClaimStateHash)
}

return ret
}

Expand Down Expand Up @@ -381,6 +366,37 @@ func (c IDTokenClaims) toNumericDate(key string) (date *NumericDate, err error)
return toNumericDate(v)
}

func toStringSlice(value any) (values []string, ok bool) {
switch t := value.(type) {
case nil:
ok = true
case string:
ok = true

values = []string{t}
case []string:
ok = true

values = t
case []any:
ok = true

loop:
for _, tv := range t {
switch vv := tv.(type) {
case string:
values = append(values, vv)
default:
ok = false

break loop
}
}
}

return values, ok
}

var (
_ Claims = (*IDTokenClaims)(nil)
)
2 changes: 2 additions & 0 deletions token/jwt/claims_id_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func TestIDTokenClaimsToMap(t *testing.T) {
ClaimIssuer: idTokenClaims.Issuer,
ClaimAudience: idTokenClaims.Audience,
ClaimExpirationTime: idTokenClaims.ExpirationTime.Unix(),
ClaimRequestedAt: idTokenClaims.RequestedAt.Unix(),
"foo": idTokenClaims.Extra["foo"],
"baz": idTokenClaims.Extra["baz"],
ClaimAccessTokenHash: idTokenClaims.AccessTokenHash,
Expand All @@ -67,6 +68,7 @@ func TestIDTokenClaimsToMap(t *testing.T) {
consts.ClaimIssuer: idTokenClaims.Issuer,
consts.ClaimAudience: idTokenClaims.Audience,
consts.ClaimExpirationTime: idTokenClaims.ExpirationTime.Unix(),
consts.ClaimRequestedAt: idTokenClaims.RequestedAt.Unix(),
"foo": idTokenClaims.Extra["foo"],
"baz": idTokenClaims.Extra["baz"],
consts.ClaimAccessTokenHash: idTokenClaims.AccessTokenHash,
Expand Down
21 changes: 20 additions & 1 deletion token/jwt/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,27 @@ func NewClientSecretJWK(ctx context.Context, secret []byte, kid, alg, enc, use s

switch use {
case JSONWebTokenUseSignature:
var (
hasher hash.Hash
)

switch jose.SignatureAlgorithm(alg) {
case jose.HS256:
hasher = sha256.New()
case jose.HS384:
hasher = sha512.New384()
case jose.HS512:
hasher = sha512.New()
default:
return nil, &JWKLookupError{Description: fmt.Sprintf("Unsupported algorithm '%s'", alg)}
}

if _, err = hasher.Write(secret); err != nil {
return nil, &JWKLookupError{Description: fmt.Sprintf("Unable to derive key from hashing the client secret. %s", err.Error())}
}

return &jose.JSONWebKey{
Key: secret,
Key: hasher.Sum(nil),
KeyID: kid,
Algorithm: alg,
Use: use,
Expand Down

0 comments on commit 6540167

Please sign in to comment.