Skip to content

Commit

Permalink
Support JWT Claims content customization
Browse files Browse the repository at this point in the history
  • Loading branch information
vasayxtx committed Nov 25, 2024
1 parent b9fbec6 commit 9b4320a
Show file tree
Hide file tree
Showing 27 changed files with 975 additions and 586 deletions.
33 changes: 26 additions & 7 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func NewJWTParser(cfg *Config, opts ...JWTParserOption) (JWTParser, error) {
ExpectedAudience: cfg.JWT.ExpectedAudience,
TrustedIssuerNotFoundFallback: options.trustedIssuerNotFoundFallback,
LoggerProvider: options.loggerProvider,
ClaimsTemplate: options.claimsTemplate,
}

if cfg.JWT.ClaimsCache.Enabled {
Expand Down Expand Up @@ -88,6 +89,7 @@ type jwtParserOptions struct {
loggerProvider func(ctx context.Context) log.FieldLogger
prometheusLibInstanceLabel string
trustedIssuerNotFoundFallback jwt.TrustedIssNotFoundFallback
claimsTemplate jwt.Claims
}

// JWTParserOption is an option for creating JWTParser.
Expand All @@ -114,6 +116,13 @@ func WithJWTParserTrustedIssuerNotFoundFallback(fallback jwt.TrustedIssNotFoundF
}
}

// WithJWTParserClaimsTemplate sets the claims template for JWTParser.
func WithJWTParserClaimsTemplate(claimsTemplate jwt.Claims) JWTParserOption {
return func(options *jwtParserOptions) {
options.claimsTemplate = claimsTemplate
}
}

// NewTokenIntrospector creates a new TokenIntrospector with the given configuration, token provider and scope filter.
// If cfg.Introspection.ClaimsCache.Enabled or cfg.Introspection.NegativeCache.Enabled is true,
// then idptoken.CachingIntrospector created, otherwise - idptoken.Introspector.
Expand All @@ -122,7 +131,7 @@ func WithJWTParserTrustedIssuerNotFoundFallback(fallback jwt.TrustedIssNotFoundF
func NewTokenIntrospector(
cfg *Config,
tokenProvider idptoken.IntrospectionTokenProvider,
scopeFilter []idptoken.IntrospectionScopeFilterAccessPolicy,
scopeFilter jwt.ScopeFilter,
opts ...TokenIntrospectorOption,
) (*idptoken.Introspector, error) {
options := tokenIntrospectorOptions{loggerProvider: middleware.GetLoggerFromContext}
Expand Down Expand Up @@ -159,6 +168,7 @@ func NewTokenIntrospector(
HTTPClient: idputil.MakeDefaultHTTPClient(cfg.HTTPClient.RequestTimeout, options.loggerProvider),
AccessTokenScope: cfg.Introspection.AccessTokenScope,
LoggerProvider: options.loggerProvider,
ResultTemplate: options.resultTemplate,
ScopeFilter: scopeFilter,
TrustedIssuerNotFoundFallback: options.trustedIssuerNotFoundFallback,
PrometheusLibInstanceLabel: options.prometheusLibInstanceLabel,
Expand Down Expand Up @@ -189,6 +199,7 @@ type tokenIntrospectorOptions struct {
loggerProvider func(ctx context.Context) log.FieldLogger
prometheusLibInstanceLabel string
trustedIssuerNotFoundFallback idptoken.TrustedIssNotFoundFallback
resultTemplate idptoken.IntrospectionResult
}

// TokenIntrospectorOption is an option for creating TokenIntrospector.
Expand Down Expand Up @@ -218,18 +229,26 @@ func WithTokenIntrospectorTrustedIssuerNotFoundFallback(
}
}

// WithTokenIntrospectorResultTemplate sets the result template for TokenIntrospector.
func WithTokenIntrospectorResultTemplate(resultTemplate idptoken.IntrospectionResult) TokenIntrospectorOption {
return func(options *tokenIntrospectorOptions) {
options.resultTemplate = resultTemplate
}
}

// Role is a representation of role which may be used for verifying access.
type Role struct {
Namespace string
Name string
}

// NewVerifyAccessByRolesInJWT creates a new function which may be used for verifying access by roles in JWT scope.
func NewVerifyAccessByRolesInJWT(roles ...Role) func(r *http.Request, claims *jwt.Claims) bool {
return func(_ *http.Request, claims *jwt.Claims) bool {
func NewVerifyAccessByRolesInJWT(roles ...Role) func(r *http.Request, claims jwt.Claims) bool {
return func(_ *http.Request, claims jwt.Claims) bool {
claimsScope := claims.GetScope()
for i := range roles {
for j := range claims.Scope {
if roles[i].Name == claims.Scope[j].Role && roles[i].Namespace == claims.Scope[j].ResourceNamespace {
for j := range claimsScope {
if roles[i].Name == claimsScope[j].Role && roles[i].Namespace == claimsScope[j].ResourceNamespace {
return true
}
}
Expand All @@ -239,8 +258,8 @@ func NewVerifyAccessByRolesInJWT(roles ...Role) func(r *http.Request, claims *jw
}

// NewVerifyAccessByRolesInJWTMaker creates a new function which may be used for verifying access by roles in JWT scope given a namespace.
func NewVerifyAccessByRolesInJWTMaker(namespace string) func(roleNames ...string) func(r *http.Request, claims *jwt.Claims) bool {
return func(roleNames ...string) func(r *http.Request, claims *jwt.Claims) bool {
func NewVerifyAccessByRolesInJWTMaker(namespace string) func(roleNames ...string) func(r *http.Request, claims jwt.Claims) bool {
return func(roleNames ...string) func(r *http.Request, claims jwt.Claims) bool {
roles := make([]Role, 0, len(roleNames))
for i := range roleNames {
roles = append(roles, Role{Namespace: namespace, Name: roleNames[i]})
Expand Down
66 changes: 33 additions & 33 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestNewJWTParser(t *gotesting.T) {
require.NoError(t, idpSrv.StartAndWaitForReady(time.Second))
defer func() { _ = idpSrv.Shutdown(context.Background()) }()

claims := &jwt.Claims{
claims := &jwt.DefaultClaims{
RegisteredClaims: jwtgo.RegisteredClaims{
Issuer: idpSrv.URL(),
ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)),
Expand All @@ -53,7 +53,7 @@ func TestNewJWTParser(t *gotesting.T) {
}
token := idptest.MustMakeTokenStringSignedWithTestKey(claims)

claimsWithNamedIssuer := &jwt.Claims{
claimsWithNamedIssuer := &jwt.DefaultClaims{
RegisteredClaims: jwtgo.RegisteredClaims{
Issuer: testIss,
ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)),
Expand All @@ -66,7 +66,7 @@ func TestNewJWTParser(t *gotesting.T) {
name string
token string
cfg *Config
expectedClaims *jwt.Claims
expectedClaims jwt.Claims
checkFn func(t *gotesting.T, jwtParser JWTParser)
}{
{
Expand Down Expand Up @@ -149,7 +149,7 @@ func TestNewTokenIntrospector(t *gotesting.T) {
require.NoError(t, grpcIDPSrv.StartAndWaitForReady(time.Second))
defer func() { grpcIDPSrv.GracefulStop() }()

claims := &jwt.Claims{
claims := &jwt.DefaultClaims{
RegisteredClaims: jwtgo.RegisteredClaims{
Issuer: httpIDPSrv.URL(),
ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)),
Expand All @@ -158,7 +158,7 @@ func TestNewTokenIntrospector(t *gotesting.T) {
}
token := idptest.MustMakeTokenStringSignedWithTestKey(claims)

claimsWithNamedIssuer := &jwt.Claims{
claimsWithNamedIssuer := &jwt.DefaultClaims{
RegisteredClaims: jwtgo.RegisteredClaims{
Issuer: testIss,
ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(10 * time.Second)),
Expand All @@ -174,8 +174,8 @@ func TestNewTokenIntrospector(t *gotesting.T) {
Role: "admin",
ResourcePath: "resource-" + uuid.NewString(),
}}
httpServerIntrospector.SetResultForToken(opaqueToken, idptoken.IntrospectionResult{
Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueTokenScope}})
httpServerIntrospector.SetResultForToken(opaqueToken, &idptoken.DefaultIntrospectionResult{
Active: true, TokenType: idputil.TokenTypeBearer, DefaultClaims: jwt.DefaultClaims{Scope: opaqueTokenScope}})
grpcServerIntrospector.SetResultForToken(opaqueToken, &pb.IntrospectTokenResponse{
Active: true, TokenType: idputil.TokenTypeBearer, Scope: []*pb.AccessTokenScope{
{
Expand All @@ -197,10 +197,10 @@ func TestNewTokenIntrospector(t *gotesting.T) {
name: "new token introspector, dynamic endpoint, trusted issuers map",
cfg: &Config{JWT: JWTConfig{TrustedIssuers: map[string]string{testIss: httpIDPSrv.URL()}}, Introspection: IntrospectionConfig{Enabled: true}},
token: tokenWithNamedIssuer,
expectedResult: idptoken.IntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
Claims: *claimsWithNamedIssuer,
expectedResult: &idptoken.DefaultIntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
DefaultClaims: *claimsWithNamedIssuer,
},
checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) {
require.Empty(t, introspector.ClaimsCache.Len(context.Background()))
Expand All @@ -211,10 +211,10 @@ func TestNewTokenIntrospector(t *gotesting.T) {
name: "new token introspector, dynamic endpoint, trusted issuer urls",
cfg: &Config{JWT: JWTConfig{TrustedIssuerURLs: []string{httpIDPSrv.URL()}}, Introspection: IntrospectionConfig{Enabled: true}},
token: token,
expectedResult: idptoken.IntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
Claims: *claims,
expectedResult: &idptoken.DefaultIntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
DefaultClaims: *claims,
},
checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) {
require.Empty(t, introspector.ClaimsCache.Len(context.Background()))
Expand All @@ -228,10 +228,10 @@ func TestNewTokenIntrospector(t *gotesting.T) {
Introspection: IntrospectionConfig{Enabled: true, ClaimsCache: IntrospectionCacheConfig{Enabled: true}},
},
token: tokenWithNamedIssuer,
expectedResult: idptoken.IntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
Claims: *claimsWithNamedIssuer,
expectedResult: &idptoken.DefaultIntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
DefaultClaims: *claimsWithNamedIssuer,
},
checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) {
require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background()))
Expand All @@ -245,10 +245,10 @@ func TestNewTokenIntrospector(t *gotesting.T) {
Introspection: IntrospectionConfig{Enabled: true, ClaimsCache: IntrospectionCacheConfig{Enabled: true}},
},
token: token,
expectedResult: idptoken.IntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
Claims: *claims,
expectedResult: &idptoken.DefaultIntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
DefaultClaims: *claims,
},
checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) {
require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background()))
Expand All @@ -265,10 +265,10 @@ func TestNewTokenIntrospector(t *gotesting.T) {
},
},
token: opaqueToken,
expectedResult: idptoken.IntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
Claims: jwt.Claims{Scope: opaqueTokenScope},
expectedResult: &idptoken.DefaultIntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
DefaultClaims: jwt.DefaultClaims{Scope: opaqueTokenScope},
},
checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) {
require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background()))
Expand All @@ -291,10 +291,10 @@ func TestNewTokenIntrospector(t *gotesting.T) {
},
},
token: opaqueToken,
expectedResult: idptoken.IntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
Claims: jwt.Claims{Scope: opaqueTokenScope},
expectedResult: &idptoken.DefaultIntrospectionResult{
Active: true,
TokenType: idputil.TokenTypeBearer,
DefaultClaims: jwt.DefaultClaims{Scope: opaqueTokenScope},
},
checkCacheFn: func(t *gotesting.T, introspector *idptoken.Introspector) {
require.Empty(t, introspector.ClaimsCache.Len(context.Background()))
Expand Down Expand Up @@ -323,7 +323,7 @@ func TestNewTokenIntrospector(t *gotesting.T) {
}

func TestNewVerifyAccessByJWTRoles(t *gotesting.T) {
jwtClaims := &jwt.Claims{Scope: []jwt.AccessPolicy{
jwtClaims := &jwt.DefaultClaims{Scope: []jwt.AccessPolicy{
{ResourceNamespace: "policy_manager", Role: "admin"},
{ResourceNamespace: "scan_service", Role: "admin"},
{Role: "backup_user"},
Expand All @@ -347,7 +347,7 @@ func TestNewVerifyAccessByJWTRoles(t *gotesting.T) {
}

func TestNewVerifyAccessByJWTRolesMaker(t *gotesting.T) {
jwtClaims := &jwt.Claims{Scope: []jwt.AccessPolicy{
jwtClaims := &jwt.DefaultClaims{Scope: []jwt.AccessPolicy{
{ResourceNamespace: "policy_manager", Role: "admin"},
{ResourceNamespace: "scan_service", Role: "admin"},
{Role: "backup_user"},
Expand Down
5 changes: 3 additions & 2 deletions examples/authn-middleware/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ func runApp() error {

srvMux := http.NewServeMux()
srvMux.Handle("/", authNMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
jwtClaims := authkit.GetJWTClaimsFromContext(r.Context()) // get JWT claims from the request context
_, _ = rw.Write([]byte(fmt.Sprintf("Hello, %s", jwtClaims.Subject)))
jwtClaims := authkit.GetJWTClaimsFromContext(r.Context()) // get JWT claims from the request context
tokenSubject, _ := jwtClaims.GetSubject() // error is always nil here unless custom claims are used
_, _ = rw.Write([]byte(fmt.Sprintf("Hello, %s", tokenSubject))) // use the subject to greet the user
})))
if err = http.ListenAndServe(":8080", middleware.Logging(logger)(srvMux)); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("listen and HTTP server: %w", err)
Expand Down
17 changes: 9 additions & 8 deletions examples/idp-test-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,17 @@ type demoTokenIntrospector struct {

func (dti *demoTokenIntrospector) IntrospectToken(r *http.Request, token string) (idptoken.IntrospectionResult, error) {
if bearerToken := authkit.GetBearerTokenFromRequest(r); bearerToken != "access-token-with-introspection-permission" {
return idptoken.IntrospectionResult{}, idptest.ErrUnauthorized
return nil, idptest.ErrUnauthorized
}
claims, err := dti.jwtParser.Parse(r.Context(), token)
if err != nil {
return idptoken.IntrospectionResult{Active: false}, nil
return &idptoken.DefaultIntrospectionResult{Active: false}, nil
}
if claims.Subject == "admin2" {
claims.Scope = append(claims.Scope, jwt.AccessPolicy{ResourceNamespace: "my_service", Role: "admin"})
defClaims := claims.(*jwt.DefaultClaims) // type assertion is safe here since we don't use custom claims
if defClaims.Subject == "admin2" {
defClaims.Scope = append(defClaims.Scope, jwt.AccessPolicy{ResourceNamespace: "my_service", Role: "admin"})
}
return idptoken.IntrospectionResult{Active: true, TokenType: "Bearer", Claims: *claims}, nil
return &idptoken.DefaultIntrospectionResult{Active: true, TokenType: "Bearer", DefaultClaims: *defClaims}, nil
}

type demoClaimsProvider struct {
Expand All @@ -92,9 +93,9 @@ type demoClaimsProvider struct {
func (dcp *demoClaimsProvider) Provide(r *http.Request) (jwt.Claims, error) {
username, password, ok := r.BasicAuth()
if !ok {
return jwt.Claims{}, idptest.ErrUnauthorized
return nil, idptest.ErrUnauthorized
}
var claims jwt.Claims
claims := &jwt.DefaultClaims{}
switch {
case username == "user" && password == "user-pwd":
claims.Subject = "user"
Expand All @@ -104,7 +105,7 @@ func (dcp *demoClaimsProvider) Provide(r *http.Request) (jwt.Claims, error) {
case username == "admin2" && password == "admin2-pwd":
claims.Subject = "admin2"
default:
return jwt.Claims{}, idptest.ErrUnauthorized
return nil, idptest.ErrUnauthorized
}
return claims, nil
}
18 changes: 10 additions & 8 deletions examples/token-introspection/grpc-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,25 @@ func (dti *demoGRPCTokenIntrospector) IntrospectToken(
if authMeta != "Bearer "+accessTokenWithIntrospectionPermission {
return nil, idptest.ErrUnauthorized
}

claims, err := dti.jwtParser.Parse(ctx, req.Token)
if err != nil {
return &pb.IntrospectTokenResponse{Active: false}, nil
}
if claims.Subject == "admin2" {
claims.Scope = append(claims.Scope, jwt.AccessPolicy{ResourceNamespace: "my_service", Role: "admin"})
defClaims := claims.(*jwt.DefaultClaims) // type assertion is safe here since we don't use custom claims
if defClaims.Subject == "admin2" {
defClaims.Scope = append(claims.GetScope(), jwt.AccessPolicy{ResourceNamespace: "my_service", Role: "admin"})
}
resp := &pb.IntrospectTokenResponse{
Active: true,
TokenType: "Bearer",
Sub: claims.Subject,
Exp: claims.ExpiresAt.Unix(),
Aud: claims.Audience,
Iss: claims.Issuer,
Scope: make([]*pb.AccessTokenScope, 0, len(claims.Scope)),
Sub: defClaims.Subject,
Exp: defClaims.ExpiresAt.Unix(),
Aud: defClaims.Audience,
Iss: defClaims.Issuer,
Scope: make([]*pb.AccessTokenScope, 0, len(defClaims.Scope)),
}
for _, policy := range claims.Scope {
for _, policy := range defClaims.Scope {
resp.Scope = append(resp.Scope, &pb.AccessTokenScope{
ResourceNamespace: policy.ResourceNamespace,
RoleName: policy.Role,
Expand Down
12 changes: 7 additions & 5 deletions examples/token-introspection/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"github.com/acronis/go-appkit/log"

"github.com/acronis/go-authkit"
"github.com/acronis/go-authkit/idptoken"
"github.com/acronis/go-authkit/jwt"
)

const (
Expand Down Expand Up @@ -49,8 +49,8 @@ func runApp() error {
}

// Create token introspector.
introspectionScopeFilter := []idptoken.IntrospectionScopeFilterAccessPolicy{{ResourceNamespace: serviceAccessPolicy}}
tokenIntrospector, err := authkit.NewTokenIntrospector(cfg.Auth, introspectionTokenProvider{}, introspectionScopeFilter)
tokenIntrospector, err := authkit.NewTokenIntrospector(cfg.Auth,
introspectionTokenProvider{}, jwt.ScopeFilter{{ResourceNamespace: serviceAccessPolicy}})
if err != nil {
return fmt.Errorf("create token introspector: %w", err)
}
Expand Down Expand Up @@ -79,12 +79,14 @@ func runApp() error {
// "/" endpoint will be available for all authenticated users.
srvMux.Handle("/", authNMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
jwtClaims := authkit.GetJWTClaimsFromContext(r.Context()) // get JWT claims from the request context
_, _ = rw.Write([]byte(fmt.Sprintf("Hello, %s", jwtClaims.Subject)))
tokenSubject, _ := jwtClaims.GetSubject() // error is always nil here unless custom claims are used
_, _ = rw.Write([]byte(fmt.Sprintf("Hello, %s", tokenSubject)))
})))
// "/admin" endpoint will be available only for users with the "admin" role.
srvMux.Handle("/admin", authZMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
jwtClaims := authkit.GetJWTClaimsFromContext(r.Context()) // Get JWT claims from the request context.
_, _ = rw.Write([]byte(fmt.Sprintf("Hi, %s", jwtClaims.Subject)))
tokenSubject, _ := jwtClaims.GetSubject() // error is always nil here unless custom claims are used
_, _ = rw.Write([]byte(fmt.Sprintf("Hi, %s", tokenSubject)))
})))
if err = http.ListenAndServe(":8080", middleware.Logging(logger)(srvMux)); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("listen and HTTP server: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion idptest/http_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func TestHTTPServerDefault(t *gotesting.T) {
respBody, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
var introspectionRespData idptoken.IntrospectionResult
var introspectionRespData idptoken.DefaultIntrospectionResult
require.NoError(t, json.Unmarshal(respBody, &introspectionRespData))
require.True(t, introspectionRespData.Active)
require.Equal(t, idpSrv.URL(), introspectionRespData.Issuer)
Expand Down
Loading

0 comments on commit 9b4320a

Please sign in to comment.