Skip to content

Commit

Permalink
Check OpenID audience when validating token. (#3541)
Browse files Browse the repository at this point in the history
* Add validation of 'aud' claim in JWT tokens, along with dev keycloak config support

* Add minder audience protocol mapper

* Switch to jwt.WithAudience

* Don't log full bearer tokens

* Set the default for aud to 'minder'

* Update client to request a specific audience in addition to openid

* Fix added test in jwauth

* Fix lint

---------

Co-authored-by: Eleftheria Stein-Kousathana <[email protected]>
  • Loading branch information
evankanderson and eleftherias authored Jun 5, 2024
1 parent e3b4022 commit 1ec9ffc
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 15 deletions.
11 changes: 5 additions & 6 deletions cmd/cli/app/auth/auth_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,11 @@ func LoginCommand(cmd *cobra.Command, _ []string) error {
// wait for the token to be received
var loginErr loginError
token, err := Login(ctx, cmd, clientConfig, nil, skipBrowser)
if errors.As(err, &loginErr) {
if loginErr.isAccessDenied() {
cmd.Println("Access denied. Please run the command again and accept the terms and conditions.")
return nil
}
} else if err != nil {
if errors.As(err, &loginErr) && loginErr.isAccessDenied() {
cmd.Println("Access denied. Please run the command again and accept the terms and conditions.")
return nil
}
if err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/cli/app/auth/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func Login(
}

issuerUrl := parsedURL.JoinPath("realms/stacklok")
scopes := []string{"openid"}
scopes := []string{"openid", "minder-audience"}

if len(extraScopes) > 0 {
scopes = append(scopes, extraScopes...)
Expand Down
2 changes: 1 addition & 1 deletion cmd/server/app/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ var serveCmd = &cobra.Command{
if err != nil {
return fmt.Errorf("failed to create JWKS URL: %w\n", err)
}
jwt, err := auth.NewJwtValidator(ctx, jwksUrl.String())
jwt, err := auth.NewJwtValidator(ctx, jwksUrl.String(), cfg.Identity.Server.Audience)
if err != nil {
return fmt.Errorf("failed to fetch and cache identity provider JWKS: %w\n", err)
}
Expand Down
1 change: 1 addition & 0 deletions config/server-config.yaml.example
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ identity:
issuer_url: http://keycloak:8080 # Use http://localhost:8081 instead for running minder outside of docker compose
client_id: minder-server
client_secret: secret
audience: minder

# Crypto (these should be ultimately stored in a secure vault)
# The token key can be generated with:
Expand Down
20 changes: 19 additions & 1 deletion identity/config/stacklok.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,23 @@ clientScopes:
attributes:
"include.in.token.scope": "true"
"display.on.consent.screen": "false"

- name: minder-audience
description: "Add minder to audience claim"
protocol: openid-connect
attributes:
"include.in.token.scope": "true"
"display.on.consent.screen": "false"
protocolMappers:
- name: minder_audience
protocol: openid-connect
protocolMapper: oidc-audience-mapper
consentRequired: false
config:
id.token.claim: "false"
access.token.claim: "true"
introspection.token.claim: "true"
included.custom.audience: "minder"
userinfo.token.claim: "false"

clients:
# From:
Expand All @@ -100,6 +116,7 @@ clients:
- roles
- web-origins
- gh-data
- minder-audience
optionalClientScopes:
- microprofile-jwt
- offline_access
Expand All @@ -118,6 +135,7 @@ clients:
- roles
- web-origins
- gh-data
- minder-audience
optionalClientScopes:
- microprofile-jwt
- offline_access
Expand Down
22 changes: 19 additions & 3 deletions internal/auth/jwauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestParseAndValidate(t *testing.T) {
{
name: "Valid token",
buildToken: func() string {
token, _ := jwt.NewBuilder().Subject("123").Expiration(time.Now().Add(time.Duration(1) * time.Minute)).Build()
token, _ := jwt.NewBuilder().Subject("123").Audience([]string{"minder"}).Expiration(time.Now().Add(time.Duration(1) * time.Minute)).Build()
signed, _ := jwt.Sign(token, jwt.WithKey(jwa.RS256, privateJwk))
return string(signed)
},
Expand All @@ -71,7 +71,7 @@ func TestParseAndValidate(t *testing.T) {
{
name: "Expired token",
buildToken: func() string {
token, _ := jwt.NewBuilder().Subject("123").Expiration(time.Now().Add(-time.Duration(1) * time.Minute)).Build()
token, _ := jwt.NewBuilder().Subject("123").Audience([]string{"minder"}).Expiration(time.Now().Add(-time.Duration(1) * time.Minute)).Build()
signed, _ := jwt.Sign(token, jwt.WithKey(jwa.RS256, privateJwk))
return string(signed)
},
Expand Down Expand Up @@ -119,6 +119,19 @@ func TestParseAndValidate(t *testing.T) {
checkError: func(t *testing.T, err error) {
t.Helper()

assert.Error(t, err)
},
},
{
name: "Missing audience claim",
buildToken: func() string {
token, _ := jwt.NewBuilder().Subject("123").Expiration(time.Now().Add(-time.Duration(1) * time.Minute)).Build()
signed, _ := jwt.Sign(token, jwt.WithKey(jwa.RS256, privateJwk))
return string(signed)
},
checkError: func(t *testing.T, err error) {
t.Helper()

assert.Error(t, err)
},
},
Expand All @@ -134,7 +147,10 @@ func TestParseAndValidate(t *testing.T) {
mockKeyFetcher := mockjwt.NewMockKeySetFetcher(ctrl)
mockKeyFetcher.EXPECT().GetKeySet().Return(jwks, nil)

jwtValidator := JwkSetJwtValidator{jwksFetcher: mockKeyFetcher}
jwtValidator := JwkSetJwtValidator{
jwksFetcher: mockKeyFetcher,
aud: "minder",
}
_, err := jwtValidator.ParseAndValidate(tc.buildToken())
tc.checkError(t, err)
})
Expand Down
11 changes: 9 additions & 2 deletions internal/auth/jwtauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type JwtValidator interface {
// JwkSetJwtValidator is a JWT validator that uses a JWK set URL to validate the tokens
type JwkSetJwtValidator struct {
jwksFetcher KeySetFetcher
aud string
}

// KeySetFetcher provides the functions to fetch a JWK set
Expand All @@ -60,7 +61,12 @@ func (j *JwkSetJwtValidator) ParseAndValidate(tokenString string) (openid.Token,
return nil, err
}

token, err := jwt.ParseString(tokenString, jwt.WithKeySet(set), jwt.WithValidate(true), jwt.WithToken(openid.New()))
token, err := jwt.ParseString(
tokenString,
jwt.WithKeySet(set),
jwt.WithValidate(true),
jwt.WithToken(openid.New()),
jwt.WithAudience(j.aud))
if err != nil {
return nil, err
}
Expand All @@ -78,7 +84,7 @@ func (j *JwkSetJwtValidator) ParseAndValidate(tokenString string) (openid.Token,
}

// NewJwtValidator creates a new JWT validator that uses a JWK set URL to validate the tokens
func NewJwtValidator(ctx context.Context, jwksUrl string) (JwtValidator, error) {
func NewJwtValidator(ctx context.Context, jwksUrl string, aud string) (JwtValidator, error) {
// Cache the JWK set
// The cache will refresh every 15 minutes by default
jwks := jwk.NewCache(ctx)
Expand All @@ -100,6 +106,7 @@ func NewJwtValidator(ctx context.Context, jwksUrl string) (JwtValidator, error)
}
return &JwkSetJwtValidator{
jwksFetcher: &keySetCache,
aud: aud,
}, nil
}

Expand Down
2 changes: 2 additions & 0 deletions internal/config/server/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ type IdentityConfig struct {
ClientSecret string `mapstructure:"client_secret" default:"secret"`
// ClientSecretFile is the location of a file containing the client secret for the minder server (optional)
ClientSecretFile string `mapstructure:"client_secret_file"`
// Audience is the expected audience for JWT tokens (see OpenID spec)
Audience string `mapstructure:"audience" default:"minder"`
}

// GetClientSecret returns the minder-server client secret
Expand Down
8 changes: 7 additions & 1 deletion internal/controlplane/handlers_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,13 @@ func TokenValidationInterceptor(ctx context.Context, req interface{}, info *grpc

parsedToken, err := server.jwt.ParseAndValidate(token)
if err != nil {
zerolog.Ctx(ctx).Info().Msgf("Error validating token %s", token)
// We don't want to _actually_ log a bearer token. JWTs will always be > 10 chars,
// but by logging the start, we can see if it's actually a JWT or something else.
shortToken := token
if len(token) > 10 {
shortToken = token[:10]
}
zerolog.Ctx(ctx).Info().Msgf("Error validating token %s", shortToken)
return nil, status.Errorf(codes.Unauthenticated, "invalid auth token: %v", err)
}

Expand Down

0 comments on commit 1ec9ffc

Please sign in to comment.