diff --git a/config.go b/config.go index 065067f..1327b4c 100644 --- a/config.go +++ b/config.go @@ -20,29 +20,32 @@ import ( ) const ( - cfgKeyHTTPClientRequestTimeout = "auth.httpClient.requestTimeout" - cfgKeyGRPCClientRequestTimeout = "auth.grpcClient.requestTimeout" - cfgKeyJWTTrustedIssuers = "auth.jwt.trustedIssuers" - cfgKeyJWTTrustedIssuerURLs = "auth.jwt.trustedIssuerUrls" - cfgKeyJWTRequireAudience = "auth.jwt.requireAudience" - cfgKeyJWTExceptedAudience = "auth.jwt.expectedAudience" - cfgKeyJWTClaimsCacheEnabled = "auth.jwt.claimsCache.enabled" - cfgKeyJWTClaimsCacheMaxEntries = "auth.jwt.claimsCache.maxEntries" - cfgKeyJWKSCacheUpdateMinInterval = "auth.jwks.cache.updateMinInterval" - cfgKeyIntrospectionEnabled = "auth.introspection.enabled" - cfgKeyIntrospectionEndpoint = "auth.introspection.endpoint" - cfgKeyIntrospectionGRPCEndpoint = "auth.introspection.grpc.endpoint" - cfgKeyIntrospectionGRPCTLSEnabled = "auth.introspection.grpc.tls.enabled" - cfgKeyIntrospectionGRPCTLSCACert = "auth.introspection.grpc.tls.caCert" - cfgKeyIntrospectionGRPCTLSClientCert = "auth.introspection.grpc.tls.clientCert" - cfgKeyIntrospectionGRPCTLSClientKey = "auth.introspection.grpc.tls.clientKey" - cfgKeyIntrospectionAccessTokenScope = "auth.introspection.accessTokenScope" // nolint:gosec // false positive - cfgKeyIntrospectionClaimsCacheEnabled = "auth.introspection.claimsCache.enabled" - cfgKeyIntrospectionClaimsCacheMaxEntries = "auth.introspection.claimsCache.maxEntries" - cfgKeyIntrospectionClaimsCacheTTL = "auth.introspection.claimsCache.ttl" - cfgKeyIntrospectionNegativeCacheEnabled = "auth.introspection.negativeCache.enabled" - cfgKeyIntrospectionNegativeCacheMaxEntries = "auth.introspection.negativeCache.maxEntries" - cfgKeyIntrospectionNegativeCacheTTL = "auth.introspection.negativeCache.ttl" + cfgKeyHTTPClientRequestTimeout = "auth.httpClient.requestTimeout" + cfgKeyGRPCClientRequestTimeout = "auth.grpcClient.requestTimeout" + cfgKeyJWTTrustedIssuers = "auth.jwt.trustedIssuers" + cfgKeyJWTTrustedIssuerURLs = "auth.jwt.trustedIssuerUrls" + cfgKeyJWTRequireAudience = "auth.jwt.requireAudience" + cfgKeyJWTExceptedAudience = "auth.jwt.expectedAudience" + cfgKeyJWTClaimsCacheEnabled = "auth.jwt.claimsCache.enabled" + cfgKeyJWTClaimsCacheMaxEntries = "auth.jwt.claimsCache.maxEntries" + cfgKeyJWKSCacheUpdateMinInterval = "auth.jwks.cache.updateMinInterval" + cfgKeyIntrospectionEnabled = "auth.introspection.enabled" + cfgKeyIntrospectionEndpoint = "auth.introspection.endpoint" + cfgKeyIntrospectionGRPCEndpoint = "auth.introspection.grpc.endpoint" + cfgKeyIntrospectionGRPCTLSEnabled = "auth.introspection.grpc.tls.enabled" + cfgKeyIntrospectionGRPCTLSCACert = "auth.introspection.grpc.tls.caCert" + cfgKeyIntrospectionGRPCTLSClientCert = "auth.introspection.grpc.tls.clientCert" + cfgKeyIntrospectionGRPCTLSClientKey = "auth.introspection.grpc.tls.clientKey" + cfgKeyIntrospectionAccessTokenScope = "auth.introspection.accessTokenScope" // nolint:gosec // false positive + cfgKeyIntrospectionClaimsCacheEnabled = "auth.introspection.claimsCache.enabled" + cfgKeyIntrospectionClaimsCacheMaxEntries = "auth.introspection.claimsCache.maxEntries" + cfgKeyIntrospectionClaimsCacheTTL = "auth.introspection.claimsCache.ttl" + cfgKeyIntrospectionNegativeCacheEnabled = "auth.introspection.negativeCache.enabled" + cfgKeyIntrospectionNegativeCacheMaxEntries = "auth.introspection.negativeCache.maxEntries" + cfgKeyIntrospectionNegativeCacheTTL = "auth.introspection.negativeCache.ttl" + cfgKeyIntrospectionEndpointDiscoveryCacheEnabled = "auth.introspection.endpointDiscoveryCache.enabled" + cfgKeyIntrospectionEndpointDiscoveryCacheMaxEntries = "auth.introspection.endpointDiscoveryCache.maxEntries" + cfgKeyIntrospectionEndpointDiscoveryCacheTTL = "auth.introspection.endpointDiscoveryCache.ttl" ) // JWTConfig is configuration of how JWT will be verified. @@ -68,8 +71,9 @@ type IntrospectionConfig struct { Endpoint string AccessTokenScope []string - ClaimsCache IntrospectionCacheConfig - NegativeCache IntrospectionCacheConfig + ClaimsCache IntrospectionCacheConfig + NegativeCache IntrospectionCacheConfig + EndpointDiscoveryCache IntrospectionCacheConfig GRPC IntrospectionGRPCConfig } @@ -145,12 +149,19 @@ func (c *Config) KeyPrefix() string { func (c *Config) SetProviderDefaults(dp config.DataProvider) { dp.SetDefault(cfgKeyHTTPClientRequestTimeout, idputil.DefaultHTTPRequestTimeout.String()) dp.SetDefault(cfgKeyGRPCClientRequestTimeout, idptoken.DefaultGRPCClientRequestTimeout.String()) + dp.SetDefault(cfgKeyJWTClaimsCacheMaxEntries, jwt.DefaultClaimsCacheMaxEntries) dp.SetDefault(cfgKeyJWKSCacheUpdateMinInterval, jwks.DefaultCacheUpdateMinInterval.String()) + dp.SetDefault(cfgKeyIntrospectionClaimsCacheMaxEntries, idptoken.DefaultIntrospectionClaimsCacheMaxEntries) dp.SetDefault(cfgKeyIntrospectionClaimsCacheTTL, idptoken.DefaultIntrospectionClaimsCacheTTL.String()) + dp.SetDefault(cfgKeyIntrospectionNegativeCacheMaxEntries, idptoken.DefaultIntrospectionNegativeCacheMaxEntries) dp.SetDefault(cfgKeyIntrospectionNegativeCacheTTL, idptoken.DefaultIntrospectionNegativeCacheTTL.String()) + + dp.SetDefault(cfgKeyIntrospectionEndpointDiscoveryCacheEnabled, true) + dp.SetDefault(cfgKeyIntrospectionEndpointDiscoveryCacheMaxEntries, idptoken.DefaultIntrospectionEndpointDiscoveryCacheMaxEntries) + dp.SetDefault(cfgKeyIntrospectionEndpointDiscoveryCacheTTL, idptoken.DefaultIntrospectionEndpointDiscoveryCacheTTL.String()) } // Set sets auth configuration values from config.DataProvider. @@ -279,5 +290,25 @@ func (c *Config) setIntrospectionConfig(dp config.DataProvider) error { return err } + // OpenID configuration cache + if c.Introspection.EndpointDiscoveryCache.Enabled, err = dp.GetBool( + cfgKeyIntrospectionEndpointDiscoveryCacheEnabled, + ); err != nil { + return err + } + if c.Introspection.EndpointDiscoveryCache.MaxEntries, err = dp.GetInt( + cfgKeyIntrospectionEndpointDiscoveryCacheMaxEntries, + ); err != nil { + return err + } + if c.Introspection.EndpointDiscoveryCache.MaxEntries < 0 { + return dp.WrapKeyErr(cfgKeyIntrospectionEndpointDiscoveryCacheMaxEntries, fmt.Errorf("max entries should be non-negative")) + } + if c.Introspection.EndpointDiscoveryCache.TTL, err = dp.GetDuration( + cfgKeyIntrospectionEndpointDiscoveryCacheTTL, + ); err != nil { + return err + } + return nil } diff --git a/config_test.go b/config_test.go index 672103b..cb82be3 100644 --- a/config_test.go +++ b/config_test.go @@ -46,13 +46,17 @@ auth: enabled: true endpoint: https://my-idp.com/introspect claimsCache: - enabled: true - maxEntries: 42000 - ttl: 42s + enabled: true + maxEntries: 42000 + ttl: 42s negativeCache: - enabled: true - maxEntries: 777 - ttl: 77s + enabled: true + maxEntries: 777 + ttl: 77m + endpointDiscoveryCache: + enabled: true + maxEntries: 73 + ttl: 7h accessTokenScope: - token_introspector grpc: @@ -98,7 +102,12 @@ auth: NegativeCache: IntrospectionCacheConfig{ Enabled: true, MaxEntries: 777, - TTL: time.Second * 77, + TTL: time.Minute * 77, + }, + EndpointDiscoveryCache: IntrospectionCacheConfig{ + Enabled: true, + MaxEntries: 73, + TTL: time.Hour * 7, }, AccessTokenScope: []string{"token_introspector"}, GRPC: IntrospectionGRPCConfig{ diff --git a/idptest/http_server.go b/idptest/http_server.go index a756f33..060dd75 100644 --- a/idptest/http_server.go +++ b/idptest/http_server.go @@ -205,10 +205,9 @@ func NewHTTPServer(options ...HTTPServerOption) *HTTPServer { }) s.Router = http.NewServeMux() - s.Router.Handle(s.paths.OpenIDConfiguration, s.OpenIDConfigurationHandler) - s.Router.Handle(s.paths.JWKS, s.KeysHandler) - s.Router.Handle(s.paths.Token, s.TokenHandler) - s.Router.Handle(s.paths.TokenIntrospection, s.TokenIntrospectionHandler) + for path, handler := range s.allHandlers() { + s.Router.Handle(path, handler) + } // nolint:gosec // This server is used for testing purposes only. s.Server = &http.Server{Handler: s.Router} @@ -256,8 +255,37 @@ func (s *HTTPServer) StartAndWaitForReady(timeout time.Duration) error { return testutil.WaitListeningServer(s.addr.Load().(string), timeout) } +// ServedCounts returns the number of requests served by each handler. +func (s *HTTPServer) ServedCounts() map[string]uint64 { + counts := make(map[string]uint64) + for path, handler := range s.allHandlers() { + if counter, ok := handler.(interface{ ServedCount() uint64 }); ok { + counts[path] = counter.ServedCount() + } + } + return counts +} + +// ResetServedCounts resets the number of requests served by each handler. +func (s *HTTPServer) ResetServedCounts() { + for _, h := range s.allHandlers() { + if r, ok := h.(interface{ ResetServedCount() }); ok { + r.ResetServedCount() + } + } +} + func (s *HTTPServer) makeJWTParser() *jwt.Parser { p := jwt.NewParser(jwks.NewClient()) _ = p.AddTrustedIssuerURL(s.URL()) return p } + +func (s *HTTPServer) allHandlers() map[string]http.Handler { + return map[string]http.Handler{ + s.paths.JWKS: s.KeysHandler, + s.paths.OpenIDConfiguration: s.OpenIDConfigurationHandler, + s.paths.Token: s.TokenHandler, + s.paths.TokenIntrospection: s.TokenIntrospectionHandler, + } +} diff --git a/idptest/jwks_handler.go b/idptest/jwks_handler.go index 32256b0..785c45b 100644 --- a/idptest/jwks_handler.go +++ b/idptest/jwks_handler.go @@ -95,6 +95,11 @@ func (h *JWKSHandler) ServedCount() uint64 { return h.servedCount.Load() } +// ResetServedCount resets the number of times JWKS handler has been served. +func (h *JWKSHandler) ResetServedCount() { + h.servedCount.Store(0) +} + type PublicJWKSResponse struct { Keys []PublicJWK `json:"keys"` } diff --git a/idptest/openid_configuration_handler.go b/idptest/openid_configuration_handler.go index 6dd8a18..a102516 100644 --- a/idptest/openid_configuration_handler.go +++ b/idptest/openid_configuration_handler.go @@ -46,6 +46,11 @@ func (h *OpenIDConfigurationHandler) ServedCount() uint64 { return h.servedCount.Load() } +// ResetServedCount resets the number of times the handler has been served. +func (h *OpenIDConfigurationHandler) ResetServedCount() { + h.servedCount.Store(0) +} + // OpenIDConfigurationResponse is a response for .well-known/openid-configuration endpoint. type OpenIDConfigurationResponse struct { TokenEndpoint string `json:"token_endpoint"` diff --git a/idptest/token_handlers.go b/idptest/token_handlers.go index 87c0f8a..03d845d 100644 --- a/idptest/token_handlers.go +++ b/idptest/token_handlers.go @@ -87,6 +87,11 @@ func (h *TokenHandler) ServedCount() uint64 { return h.servedCount.Load() } +// ResetServedCount resets the number of times the handler has been served. +func (h *TokenHandler) ResetServedCount() { + h.servedCount.Store(0) +} + // TokenResponse is a response for POST /idp/token endpoint. type TokenResponse struct { AccessToken string `json:"access_token"` @@ -149,3 +154,8 @@ func (h *TokenIntrospectionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Re func (h *TokenIntrospectionHandler) ServedCount() uint64 { return h.servedCount.Load() } + +// ResetServedCount resets the number of times the handler has been served. +func (h *TokenIntrospectionHandler) ResetServedCount() { + h.servedCount.Store(0) +} diff --git a/idptoken/introspector.go b/idptoken/introspector.go index c64c23c..b202093 100644 --- a/idptoken/introspector.go +++ b/idptoken/introspector.go @@ -47,7 +47,13 @@ const ( DefaultIntrospectionNegativeCacheMaxEntries = 1000 // DefaultIntrospectionNegativeCacheTTL is a default time-to-live for the negative cache. - DefaultIntrospectionNegativeCacheTTL = 10 * time.Minute + DefaultIntrospectionNegativeCacheTTL = 1 * time.Hour + + // DefaultIntrospectionEndpointDiscoveryCacheMaxEntries is a default maximum number of entries in the endpoint discovery cache. + DefaultIntrospectionEndpointDiscoveryCacheMaxEntries = 1000 + + // DefaultIntrospectionEndpointDiscoveryCacheTTL is a default time-to-live for the endpoint discovery cache. + DefaultIntrospectionEndpointDiscoveryCacheTTL = 1 * time.Hour ) // ErrTokenNotIntrospectable is returned when token is not introspectable. @@ -119,6 +125,9 @@ type IntrospectorOpts struct { // NegativeCache is a configuration of how negative cache will be used. NegativeCache IntrospectorCacheOpts + + // EndpointDiscoveryCache is a configuration of how endpoint discovery cache will be used. + EndpointDiscoveryCache IntrospectorCacheOpts } // IntrospectorCacheOpts is a configuration of how cache will be used. @@ -144,6 +153,9 @@ type Introspector struct { // NegativeCache is a cache for storing info about tokens that are not active. NegativeCache IntrospectionNegativeCache + // EndpointDiscoveryCache is a cache for storing OpenID configuration. + EndpointDiscoveryCache IntrospectionEndpointDiscoveryCache + accessTokenProvider IntrospectionTokenProvider accessTokenProviderInvalidatedAt atomic.Value accessTokenScope []string @@ -162,8 +174,9 @@ type Introspector struct { promMetrics *metrics.PrometheusMetrics - claimsCacheTTL time.Duration - negativeCacheTTL time.Duration + claimsCacheTTL time.Duration + negativeCacheTTL time.Duration + endpointDiscoveryCacheTTL time.Duration } // IntrospectionResult is a struct for introspection result. @@ -205,6 +218,10 @@ func NewIntrospector(tokenProvider IntrospectionTokenProvider) (*Introspector, e // NewIntrospectorWithOpts creates a new Introspector with the given token provider and options. // See IntrospectorOpts for more details. func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opts IntrospectorOpts) (*Introspector, error) { + if accessTokenProvider == nil { + return nil, errors.New("access token provider is required") + } + if opts.HTTPClient == nil { opts.HTTPClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, opts.LoggerProvider) } @@ -217,38 +234,17 @@ func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opt promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, tokenIntrospectorPromSource) - // Building claims cache if needed. - var claimsCache IntrospectionClaimsCache = &disabledIntrospectionClaimsCache{} - if opts.ClaimsCache.Enabled { - if opts.ClaimsCache.TTL == 0 { - opts.ClaimsCache.TTL = DefaultIntrospectionClaimsCacheTTL - } - if opts.ClaimsCache.MaxEntries == 0 { - opts.ClaimsCache.MaxEntries = DefaultIntrospectionClaimsCacheMaxEntries - } - cache, err := lrucache.New[[sha256.Size]byte, IntrospectionClaimsCacheItem]( - opts.ClaimsCache.MaxEntries, promMetrics.TokenClaimsCache) - if err != nil { - return nil, err - } - claimsCache = &IntrospectionLRUCache[[sha256.Size]byte, IntrospectionClaimsCacheItem]{cache} + claimsCache := makeIntrospectionClaimsCache(opts.ClaimsCache, promMetrics) + if opts.ClaimsCache.TTL == 0 { + opts.ClaimsCache.TTL = DefaultIntrospectionClaimsCacheTTL } - - // Building negative cache if needed. - var negativeCache IntrospectionNegativeCache = &disabledIntrospectionNegativeCache{} - if opts.NegativeCache.Enabled { - if opts.NegativeCache.TTL == 0 { - opts.NegativeCache.TTL = DefaultIntrospectionNegativeCacheTTL - } - if opts.NegativeCache.MaxEntries == 0 { - opts.NegativeCache.MaxEntries = DefaultIntrospectionNegativeCacheMaxEntries - } - cache, err := lrucache.New[[sha256.Size]byte, IntrospectionNegativeCacheItem]( - opts.NegativeCache.MaxEntries, promMetrics.TokenNegativeCache) - if err != nil { - return nil, err - } - negativeCache = &IntrospectionLRUCache[[sha256.Size]byte, IntrospectionNegativeCacheItem]{cache} + negativeCache := makeIntrospectionNegativeCache(opts.NegativeCache, promMetrics) + if opts.NegativeCache.TTL == 0 { + opts.NegativeCache.TTL = DefaultIntrospectionNegativeCacheTTL + } + endpointDiscoveryCache := makeIntrospectionEndpointDiscoveryCache(opts.EndpointDiscoveryCache, promMetrics) + if opts.EndpointDiscoveryCache.TTL == 0 { + opts.EndpointDiscoveryCache.TTL = DefaultIntrospectionEndpointDiscoveryCacheTTL } return &Introspector{ @@ -268,6 +264,8 @@ func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opt claimsCacheTTL: opts.ClaimsCache.TTL, NegativeCache: negativeCache, negativeCacheTTL: opts.NegativeCache.TTL, + EndpointDiscoveryCache: endpointDiscoveryCache, + endpointDiscoveryCacheTTL: opts.EndpointDiscoveryCache.TTL, }, nil } @@ -502,8 +500,17 @@ func (i *Introspector) makeIntrospectFuncGRPC() introspectFunc { } func (i *Introspector) getWellKnownIntrospectionEndpointURL(ctx context.Context, issuerURL string) (string, error) { + cacheKey := sha256.Sum256( + unsafe.Slice(unsafe.StringData(issuerURL), len(issuerURL))) // nolint:gosec // prevent redundant slice copying + + if c, ok := i.EndpointDiscoveryCache.Get(ctx, cacheKey); ok { + if c.CreatedAt.Add(i.endpointDiscoveryCacheTTL).After(time.Now()) { + return c.IntrospectionEndpoint, nil + } + } + logger := idputil.GetLoggerFromProvider(ctx, i.loggerProvider) - openIDCfgURL := strings.TrimSuffix(issuerURL, "/") + wellKnownPath + openIDCfgURL := strings.TrimSuffix(issuerURL, "/") + idputil.OpenIDConfigurationPath openIDCfg, err := idputil.GetOpenIDConfiguration( ctx, i.HTTPClient, openIDCfgURL, nil, logger, i.promMetrics) if err != nil { @@ -512,6 +519,12 @@ func (i *Introspector) getWellKnownIntrospectionEndpointURL(ctx context.Context, if openIDCfg.IntrospectionEndpoint == "" { return "", fmt.Errorf("no introspection endpoint URL found on %s", openIDCfgURL) } + + i.EndpointDiscoveryCache.Add(ctx, cacheKey, IntrospectionEndpointDiscoveryCacheItem{ + IntrospectionEndpoint: openIDCfg.IntrospectionEndpoint, + CreatedAt: time.Now(), + }) + return openIDCfg.IntrospectionEndpoint, nil } @@ -611,6 +624,18 @@ type IntrospectionClaimsCache interface { Len(ctx context.Context) int } +func makeIntrospectionClaimsCache(opts IntrospectorCacheOpts, promMetrics *metrics.PrometheusMetrics) IntrospectionClaimsCache { + if !opts.Enabled { + return &disabledIntrospectionClaimsCache{} + } + if opts.MaxEntries <= 0 { + opts.MaxEntries = DefaultIntrospectionClaimsCacheMaxEntries + } + cache, _ := lrucache.New[[sha256.Size]byte, IntrospectionClaimsCacheItem]( + opts.MaxEntries, promMetrics.TokenClaimsCache) // error is always nil here + return &IntrospectionLRUCache[[sha256.Size]byte, IntrospectionClaimsCacheItem]{cache} +} + type IntrospectionNegativeCacheItem struct { CreatedAt time.Time } @@ -622,6 +647,50 @@ type IntrospectionNegativeCache interface { Len(ctx context.Context) int } +func makeIntrospectionNegativeCache(opts IntrospectorCacheOpts, promMetrics *metrics.PrometheusMetrics) IntrospectionNegativeCache { + if !opts.Enabled { + return &disabledIntrospectionNegativeCache{} + } + if opts.TTL == 0 { + opts.TTL = DefaultIntrospectionNegativeCacheTTL + } + if opts.MaxEntries <= 0 { + opts.MaxEntries = DefaultIntrospectionNegativeCacheMaxEntries + } + cache, _ := lrucache.New[[sha256.Size]byte, IntrospectionNegativeCacheItem]( + opts.MaxEntries, promMetrics.TokenNegativeCache) // error is always nil here + return &IntrospectionLRUCache[[sha256.Size]byte, IntrospectionNegativeCacheItem]{cache} +} + +type IntrospectionEndpointDiscoveryCacheItem struct { + IntrospectionEndpoint string + CreatedAt time.Time +} + +type IntrospectionEndpointDiscoveryCache interface { + Get(ctx context.Context, key [sha256.Size]byte) (IntrospectionEndpointDiscoveryCacheItem, bool) + Add(ctx context.Context, key [sha256.Size]byte, value IntrospectionEndpointDiscoveryCacheItem) + Purge(ctx context.Context) + Len(ctx context.Context) int +} + +func makeIntrospectionEndpointDiscoveryCache( + opts IntrospectorCacheOpts, promMetrics *metrics.PrometheusMetrics, +) IntrospectionEndpointDiscoveryCache { + if !opts.Enabled { + return &disabledIntrospectionEndpointDiscoveryCache{} + } + if opts.TTL == 0 { + opts.TTL = DefaultIntrospectionEndpointDiscoveryCacheTTL + } + if opts.MaxEntries <= 0 { + opts.MaxEntries = DefaultIntrospectionEndpointDiscoveryCacheMaxEntries + } + cache, _ := lrucache.New[[sha256.Size]byte, IntrospectionEndpointDiscoveryCacheItem]( + opts.MaxEntries, promMetrics.EndpointDiscoveryCache) // error is always nil here + return &IntrospectionLRUCache[[sha256.Size]byte, IntrospectionEndpointDiscoveryCacheItem]{cache} +} + type IntrospectionLRUCache[K comparable, V any] struct { cache *lrucache.LRUCache[K, V] } @@ -661,3 +730,17 @@ func (c *disabledIntrospectionNegativeCache) Add(ctx context.Context, key [sha25 } func (c *disabledIntrospectionNegativeCache) Purge(ctx context.Context) {} func (c *disabledIntrospectionNegativeCache) Len(ctx context.Context) int { return 0 } + +type disabledIntrospectionEndpointDiscoveryCache struct{} + +func (c *disabledIntrospectionEndpointDiscoveryCache) Get( + ctx context.Context, key [sha256.Size]byte, +) (IntrospectionEndpointDiscoveryCacheItem, bool) { + return IntrospectionEndpointDiscoveryCacheItem{}, false +} +func (c *disabledIntrospectionEndpointDiscoveryCache) Add( + ctx context.Context, key [sha256.Size]byte, value IntrospectionEndpointDiscoveryCacheItem, +) { +} +func (c *disabledIntrospectionEndpointDiscoveryCache) Purge(ctx context.Context) {} +func (c *disabledIntrospectionEndpointDiscoveryCache) Len(ctx context.Context) int { return 0 } diff --git a/idptoken/introspector_test.go b/idptoken/introspector_test.go index 65868b3..8025a33 100644 --- a/idptoken/introspector_test.go +++ b/idptoken/introspector_test.go @@ -22,7 +22,6 @@ import ( "github.com/acronis/go-authkit/idptoken/pb" "github.com/acronis/go-authkit/internal/idputil" "github.com/acronis/go-authkit/internal/testing" - "github.com/acronis/go-authkit/jwks" "github.com/acronis/go-authkit/jwt" ) @@ -47,11 +46,6 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { require.NoError(t, err) defer func() { require.NoError(t, grpcClient.Close()) }() - jwtParser := jwt.NewParser(jwks.NewClient()) - require.NoError(t, jwtParser.AddTrustedIssuerURL(httpIDPSrv.URL())) - httpServerIntrospector.JWTParser = jwtParser - grpcServerIntrospector.JWTParser = jwtParser - jwtScopeToGRPC := func(jwtScope []jwt.AccessPolicy) []*pb.AccessTokenScope { grpcScope := make([]*pb.AccessTokenScope, len(jwtScope)) for i, scope := range jwtScope { @@ -65,17 +59,35 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { return grpcScope } - jwtExpiresAtInFuture := jwtgo.NewNumericDate(time.Now().Add(time.Hour)) - jwtIssuer := httpIDPSrv.URL() - jwtSubject := uuid.NewString() - jwtID := uuid.NewString() - jwtScope := []jwt.AccessPolicy{{ + // Expired JWT + expiredJWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Issuer: httpIDPSrv.URL(), + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(-time.Hour)), + }, + }) + httpServerIntrospector.SetResultForToken(expiredJWT, idptoken.IntrospectionResult{Active: false}) + + // Valid JWT with scope + validJWTScope := []jwt.AccessPolicy{{ TenantUUID: uuid.NewString(), ResourceNamespace: "account-server", Role: "account_viewer", ResourcePath: "resource-" + uuid.NewString(), }} + validJWTClaims := jwtgo.RegisteredClaims{ + Issuer: httpIDPSrv.URL(), + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + } + validJWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{RegisteredClaims: validJWTClaims}) + httpServerIntrospector.SetResultForToken(validJWT, idptoken.IntrospectionResult{Active: true, + TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{RegisteredClaims: validJWTClaims, Scope: validJWTScope}}) + // Opaque token opaqueToken := "opaque-token-" + uuid.NewString() opaqueTokenScope := []jwt.AccessPolicy{{ TenantUUID: uuid.NewString(), @@ -83,13 +95,10 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { Role: "admin", ResourcePath: "resource-" + uuid.NewString(), }} - - httpServerIntrospector.SetScopeForJWTID(jwtID, jwtScope) httpServerIntrospector.SetResultForToken(opaqueToken, idptoken.IntrospectionResult{ Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueTokenScope}}) - grpcServerIntrospector.SetScopeForJWTID(jwtID, jwtScopeToGRPC(jwtScope)) - grpcServerIntrospector.SetResultForToken(opaqueToken, &pb.IntrospectTokenResponse{ - Active: true, TokenType: idputil.TokenTypeBearer, Scope: jwtScopeToGRPC(opaqueTokenScope)}) + grpcServerIntrospector.SetResultForToken(opaqueToken, &pb.IntrospectTokenResponse{Active: true, + TokenType: idputil.TokenTypeBearer, Scope: jwtScopeToGRPC(opaqueTokenScope)}) tests := []struct { name string @@ -183,40 +192,18 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { }, }, { - name: "ok, dynamic introspection endpoint, introspected token is expired JWT", - tokenToIntrospect: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ - RegisteredClaims: jwtgo.RegisteredClaims{ - Issuer: httpIDPSrv.URL(), - Subject: uuid.NewString(), - ID: uuid.NewString(), - ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(-time.Hour)), - }, - }), + name: "ok, dynamic introspection endpoint, introspected token is expired JWT", + tokenToIntrospect: expiredJWT, expectedResult: idptoken.IntrospectionResult{Active: false}, expectedHTTPSrvCalled: true, }, { - name: "ok, dynamic introspection endpoint, introspected token is JWT", - tokenToIntrospect: idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ - RegisteredClaims: jwtgo.RegisteredClaims{ - Issuer: jwtIssuer, - Subject: jwtSubject, - ID: jwtID, - ExpiresAt: jwtExpiresAtInFuture, - }, - }), + name: "ok, dynamic introspection endpoint, introspected token is JWT", + tokenToIntrospect: validJWT, expectedResult: idptoken.IntrospectionResult{ Active: true, TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{ - RegisteredClaims: jwtgo.RegisteredClaims{ - Issuer: jwtIssuer, - Subject: jwtSubject, - ID: jwtID, - ExpiresAt: jwtExpiresAtInFuture, - }, - Scope: jwtScope, - }, + Claims: jwt.Claims{RegisteredClaims: validJWTClaims, Scope: validJWTScope}, }, expectedHTTPSrvCalled: true, }, @@ -352,21 +339,7 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { require.NoError(t, idpSrv.StartAndWaitForReady(time.Second)) defer func() { _ = idpSrv.Shutdown(context.Background()) }() - jwtParser := jwt.NewParser(jwks.NewClient()) - require.NoError(t, jwtParser.AddTrustedIssuerURL(idpSrv.URL())) - serverIntrospector.JWTParser = jwtParser - - jwtExpiresAtInFuture := jwtgo.NewNumericDate(time.Now().Add(time.Hour)) - jwtIssuer := idpSrv.URL() - jwtSubject := uuid.NewString() - jwtID := uuid.NewString() - jwtScope := []jwt.AccessPolicy{{ - TenantUUID: uuid.NewString(), - ResourceNamespace: "account-server", - Role: "account_viewer", - ResourcePath: "resource-" + uuid.NewString(), - }} - + // Expired JWT expiredJWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: idpSrv.URL(), @@ -375,15 +348,41 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(-time.Hour)), }, }) - activeJWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{ - RegisteredClaims: jwtgo.RegisteredClaims{ - Issuer: jwtIssuer, - Subject: jwtSubject, - ID: jwtID, - ExpiresAt: jwtExpiresAtInFuture, - }, - }) + serverIntrospector.SetResultForToken(expiredJWT, idptoken.IntrospectionResult{Active: false}) + // Valid JWTs with scope + validJWT1Scope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "account_viewer", + ResourcePath: "resource-" + uuid.NewString(), + }} + validJWT1Claims := jwtgo.RegisteredClaims{ + Issuer: idpSrv.URL(), + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(2 * time.Hour)), + } + valid1JWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{RegisteredClaims: validJWT1Claims}) + serverIntrospector.SetResultForToken(valid1JWT, idptoken.IntrospectionResult{Active: true, + TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{RegisteredClaims: validJWT1Claims, Scope: validJWT1Scope}}) + validJWT2Scope := []jwt.AccessPolicy{{ + TenantUUID: uuid.NewString(), + ResourceNamespace: "account-server", + Role: "account_viewer", + ResourcePath: "resource-" + uuid.NewString(), + }} + validJWT2Claims := jwtgo.RegisteredClaims{ + Issuer: idpSrv.URL(), + Subject: uuid.NewString(), + ID: uuid.NewString(), + ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), + } + valid2JWT := idptest.MustMakeTokenStringSignedWithTestKey(jwt.Claims{RegisteredClaims: validJWT2Claims}) + serverIntrospector.SetResultForToken(valid2JWT, idptoken.IntrospectionResult{Active: true, + TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{RegisteredClaims: validJWT2Claims, Scope: validJWT2Scope}}) + + // Opaque tokens opaqueToken1 := "opaque-token-" + uuid.NewString() opaqueToken2 := "opaque-token-" + uuid.NewString() opaqueToken3 := "opaque-token-" + uuid.NewString() @@ -399,8 +398,6 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { Role: "admin", ResourcePath: "resource-" + uuid.NewString(), }} - - serverIntrospector.SetScopeForJWTID(jwtID, jwtScope) serverIntrospector.SetResultForToken(opaqueToken1, idptoken.IntrospectionResult{ Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}) serverIntrospector.SetResultForToken(opaqueToken2, idptoken.IntrospectionResult{ @@ -411,7 +408,7 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { name string introspectorOpts idptoken.IntrospectorOpts tokens []string - expectedSrvCalled []bool + expectedSrvCounts []map[string]uint64 expectedResult []idptoken.IntrospectionResult checkError []func(t *gotesting.T, err error) checkIntrospector func(t *gotesting.T, introspector *idptoken.Introspector) @@ -420,7 +417,7 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { { name: "error, token is not introspectable", tokens: []string{"", "opaque-token"}, - expectedSrvCalled: []bool{false, false}, + expectedSrvCounts: []map[string]uint64{{}, {}}, introspectorOpts: idptoken.IntrospectorOpts{ ClaimsCache: idptoken.IntrospectorCacheOpts{Enabled: true}, NegativeCache: idptoken.IntrospectorCacheOpts{Enabled: true}, @@ -438,57 +435,87 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { checkIntrospector: func(t *gotesting.T, introspector *idptoken.Introspector) { require.Equal(t, 0, introspector.ClaimsCache.Len(context.Background())) require.Equal(t, 0, introspector.NegativeCache.Len(context.Background())) + require.Equal(t, 0, introspector.EndpointDiscoveryCache.Len(context.Background())) }, }, { name: "ok, dynamic introspection endpoint, introspected token is expired JWT", introspectorOpts: idptoken.IntrospectorOpts{ - ClaimsCache: idptoken.IntrospectorCacheOpts{Enabled: true}, - NegativeCache: idptoken.IntrospectorCacheOpts{Enabled: true}, + ClaimsCache: idptoken.IntrospectorCacheOpts{Enabled: true}, + NegativeCache: idptoken.IntrospectorCacheOpts{Enabled: true}, + EndpointDiscoveryCache: idptoken.IntrospectorCacheOpts{Enabled: true}, }, - tokens: repeat(expiredJWT, 2), - expectedSrvCalled: []bool{true, false}, - expectedResult: []idptoken.IntrospectionResult{{Active: false}, {Active: false}}, + tokens: repeat(expiredJWT, 2), + expectedSrvCounts: []map[string]uint64{ + {idptest.TokenIntrospectionEndpointPath: 1, idptest.OpenIDConfigurationPath: 1}, + {}, + }, + expectedResult: []idptoken.IntrospectionResult{{Active: false}, {Active: false}}, checkIntrospector: func(t *gotesting.T, introspector *idptoken.Introspector) { require.Equal(t, 0, introspector.ClaimsCache.Len(context.Background())) require.Equal(t, 1, introspector.NegativeCache.Len(context.Background())) + require.Equal(t, 1, introspector.EndpointDiscoveryCache.Len(context.Background())) }, }, { name: "ok, dynamic introspection endpoint, introspected token is JWT", introspectorOpts: idptoken.IntrospectorOpts{ - ClaimsCache: idptoken.IntrospectorCacheOpts{Enabled: true}, - NegativeCache: idptoken.IntrospectorCacheOpts{Enabled: true}, + ClaimsCache: idptoken.IntrospectorCacheOpts{Enabled: true}, + NegativeCache: idptoken.IntrospectorCacheOpts{Enabled: true}, + EndpointDiscoveryCache: idptoken.IntrospectorCacheOpts{Enabled: true}, }, - tokens: repeat(activeJWT, 2), - expectedSrvCalled: []bool{true, false}, - expectedResult: repeat(idptoken.IntrospectionResult{ - Active: true, - TokenType: idputil.TokenTypeBearer, - Claims: jwt.Claims{ - RegisteredClaims: jwtgo.RegisteredClaims{ - Issuer: jwtIssuer, - Subject: jwtSubject, - ID: jwtID, - ExpiresAt: jwtExpiresAtInFuture, - }, - Scope: jwtScope, + tokens: []string{valid1JWT, valid1JWT, valid2JWT, valid2JWT}, + expectedSrvCounts: []map[string]uint64{ + {idptest.TokenIntrospectionEndpointPath: 1, idptest.OpenIDConfigurationPath: 1}, + {}, + {idptest.TokenIntrospectionEndpointPath: 1}, + {}, + }, + expectedResult: []idptoken.IntrospectionResult{ + { + Active: true, + TokenType: idputil.TokenTypeBearer, + Claims: jwt.Claims{RegisteredClaims: validJWT1Claims, Scope: validJWT1Scope}, + }, + { + Active: true, + TokenType: idputil.TokenTypeBearer, + Claims: jwt.Claims{RegisteredClaims: validJWT1Claims, Scope: validJWT1Scope}, }, - }, 2), + { + Active: true, + TokenType: idputil.TokenTypeBearer, + Claims: jwt.Claims{RegisteredClaims: validJWT2Claims, Scope: validJWT2Scope}, + }, + { + Active: true, + TokenType: idputil.TokenTypeBearer, + Claims: jwt.Claims{RegisteredClaims: validJWT2Claims, Scope: validJWT2Scope}, + }, + }, checkIntrospector: func(t *gotesting.T, introspector *idptoken.Introspector) { - require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background())) + require.Equal(t, 2, introspector.ClaimsCache.Len(context.Background())) require.Equal(t, 0, introspector.NegativeCache.Len(context.Background())) + require.Equal(t, 1, introspector.EndpointDiscoveryCache.Len(context.Background())) }, }, { name: "ok, static introspection endpoint, introspected token is opaque", introspectorOpts: idptoken.IntrospectorOpts{ - HTTPEndpoint: idpSrv.URL() + idptest.TokenIntrospectionEndpointPath, - ClaimsCache: idptoken.IntrospectorCacheOpts{Enabled: true}, - NegativeCache: idptoken.IntrospectorCacheOpts{Enabled: true}, + HTTPEndpoint: idpSrv.URL() + idptest.TokenIntrospectionEndpointPath, + ClaimsCache: idptoken.IntrospectorCacheOpts{Enabled: true}, + NegativeCache: idptoken.IntrospectorCacheOpts{Enabled: true}, + EndpointDiscoveryCache: idptoken.IntrospectorCacheOpts{Enabled: true}, + }, + tokens: []string{opaqueToken1, opaqueToken1, opaqueToken2, opaqueToken2, opaqueToken3, opaqueToken3}, + expectedSrvCounts: []map[string]uint64{ + {idptest.TokenIntrospectionEndpointPath: 1}, + {idptest.TokenIntrospectionEndpointPath: 0}, + {idptest.TokenIntrospectionEndpointPath: 1}, + {idptest.TokenIntrospectionEndpointPath: 0}, + {idptest.TokenIntrospectionEndpointPath: 1}, + {idptest.TokenIntrospectionEndpointPath: 0}, }, - tokens: []string{opaqueToken1, opaqueToken1, opaqueToken2, opaqueToken2, opaqueToken3, opaqueToken3}, - expectedSrvCalled: []bool{true, false, true, false, true, false}, expectedResult: []idptoken.IntrospectionResult{ {Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, {Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, @@ -500,17 +527,24 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { checkIntrospector: func(t *gotesting.T, introspector *idptoken.Introspector) { require.Equal(t, 2, introspector.ClaimsCache.Len(context.Background())) require.Equal(t, 1, introspector.NegativeCache.Len(context.Background())) + require.Equal(t, 0, introspector.EndpointDiscoveryCache.Len(context.Background())) }, }, { - name: "ok, cache has ttl", + name: "ok, static introspection endpoint, cache has ttl", introspectorOpts: idptoken.IntrospectorOpts{ - HTTPEndpoint: idpSrv.URL() + idptest.TokenIntrospectionEndpointPath, - ClaimsCache: idptoken.IntrospectorCacheOpts{Enabled: true, TTL: 100 * time.Millisecond}, - NegativeCache: idptoken.IntrospectorCacheOpts{Enabled: true, TTL: 100 * time.Millisecond}, + HTTPEndpoint: idpSrv.URL() + idptest.TokenIntrospectionEndpointPath, + ClaimsCache: idptoken.IntrospectorCacheOpts{Enabled: true, TTL: 100 * time.Millisecond}, + NegativeCache: idptoken.IntrospectorCacheOpts{Enabled: true, TTL: 100 * time.Millisecond}, + EndpointDiscoveryCache: idptoken.IntrospectorCacheOpts{Enabled: true, TTL: 100 * time.Millisecond}, + }, + tokens: []string{opaqueToken1, opaqueToken1, opaqueToken3, opaqueToken3}, + expectedSrvCounts: []map[string]uint64{ + {idptest.TokenIntrospectionEndpointPath: 1}, + {idptest.TokenIntrospectionEndpointPath: 1}, + {idptest.TokenIntrospectionEndpointPath: 1}, + {idptest.TokenIntrospectionEndpointPath: 1}, }, - tokens: []string{opaqueToken1, opaqueToken1, opaqueToken3, opaqueToken3}, - expectedSrvCalled: []bool{true, true, true, true}, expectedResult: []idptoken.IntrospectionResult{ {Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, {Active: true, TokenType: idputil.TokenTypeBearer, Claims: jwt.Claims{Scope: opaqueToken1Scope}}, @@ -520,6 +554,7 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { checkIntrospector: func(t *gotesting.T, introspector *idptoken.Introspector) { require.Equal(t, 1, introspector.ClaimsCache.Len(context.Background())) require.Equal(t, 1, introspector.NegativeCache.Len(context.Background())) + require.Equal(t, 0, introspector.EndpointDiscoveryCache.Len(context.Background())) }, delay: 200 * time.Millisecond, }, @@ -532,7 +567,7 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { require.NoError(t, introspector.AddTrustedIssuerURL(idpSrv.URL())) for i, token := range tt.tokens { - serverIntrospector.ResetCallsInfo() + idpSrv.ResetServedCounts() result, introspectErr := introspector.IntrospectToken(context.Background(), token) if i < len(tt.checkError) { @@ -542,8 +577,12 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { require.Equal(t, tt.expectedResult[i], result) } - require.Equal(t, tt.expectedSrvCalled[i], serverIntrospector.Called) - if tt.expectedSrvCalled[i] { + require.Equal(t, tt.expectedSrvCounts[i][idptest.TokenIntrospectionEndpointPath], + idpSrv.ServedCounts()[idptest.TokenIntrospectionEndpointPath]) + require.Equal(t, tt.expectedSrvCounts[i][idptest.OpenIDConfigurationPath], + idpSrv.ServedCounts()[idptest.OpenIDConfigurationPath]) + + if tt.expectedSrvCounts[i][idptest.TokenIntrospectionEndpointPath] > 0 { require.Equal(t, token, serverIntrospector.LastIntrospectedToken) require.Equal(t, "Bearer "+accessToken, serverIntrospector.LastAuthorizationHeader) require.Equal(t, url.Values{"token": {token}}, serverIntrospector.LastFormValues) diff --git a/idptoken/provider.go b/idptoken/provider.go index 020d302..34d938f 100644 --- a/idptoken/provider.go +++ b/idptoken/provider.go @@ -32,7 +32,6 @@ const ( defaultMinRefreshPeriod = time.Second * 10 defaultExpirationOffset = time.Minute * 30 expiryDeltaMaxOffset = 5 - wellKnownPath = "/.well-known/openid-configuration" ) // ErrSourceNotRegistered is returned if GetToken is requested for the unknown Source @@ -541,7 +540,7 @@ func (ti *oauth2Issuer) ensureTokenURL(ctx context.Context, customHeaders map[st return nil } - openIDCfgURL := strings.TrimSuffix(ti.baseURL, "/") + wellKnownPath + openIDCfgURL := strings.TrimSuffix(ti.baseURL, "/") + idputil.OpenIDConfigurationPath openIDCfg, err := idputil.GetOpenIDConfiguration( ctx, ti.httpClient, openIDCfgURL, customHeaders, ti.logger, ti.promMetrics) if err != nil { diff --git a/internal/idputil/openid_configuration.go b/internal/idputil/openid_configuration.go index 370bb08..5d0dfdf 100644 --- a/internal/idputil/openid_configuration.go +++ b/internal/idputil/openid_configuration.go @@ -18,6 +18,8 @@ import ( "github.com/acronis/go-authkit/internal/metrics" ) +const OpenIDConfigurationPath = "/.well-known/openid-configuration" + type OpenIDConfiguration struct { TokenURL string `json:"token_endpoint"` IntrospectionEndpoint string `json:"introspection_endpoint"` diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index e7cf519..ee17f8a 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -60,6 +60,7 @@ type PrometheusMetrics struct { GRPCClientRequestDuration *prometheus.HistogramVec TokenClaimsCache *lrucache.PrometheusMetrics TokenNegativeCache *lrucache.PrometheusMetrics + EndpointDiscoveryCache *lrucache.PrometheusMetrics } func GetPrometheusMetrics(instance string, source string) *PrometheusMetrics { @@ -117,11 +118,18 @@ func newPrometheusMetrics() *PrometheusMetrics { CurriedLabelNames: curriedLabelNames, }) + endpointDiscoveryCache := lrucache.NewPrometheusMetricsWithOpts(lrucache.PrometheusMetricsOpts{ + Namespace: PrometheusNamespace + "_openid_configuration", + ConstLabels: PrometheusLabels(), + CurriedLabelNames: curriedLabelNames, + }) + return &PrometheusMetrics{ HTTPClientRequestDuration: httpClientReqDuration, GRPCClientRequestDuration: grpcClientReqDuration, TokenClaimsCache: tokenClaimsCache, TokenNegativeCache: tokenNegativeCache, + EndpointDiscoveryCache: endpointDiscoveryCache, } } @@ -132,6 +140,7 @@ func (pm *PrometheusMetrics) MustCurryWith(labels prometheus.Labels) *Prometheus GRPCClientRequestDuration: pm.GRPCClientRequestDuration.MustCurryWith(labels).(*prometheus.HistogramVec), TokenClaimsCache: pm.TokenClaimsCache.MustCurryWith(labels), TokenNegativeCache: pm.TokenNegativeCache.MustCurryWith(labels), + EndpointDiscoveryCache: pm.EndpointDiscoveryCache.MustCurryWith(labels), } } @@ -143,6 +152,7 @@ func (pm *PrometheusMetrics) MustRegister() { ) pm.TokenClaimsCache.MustRegister() pm.TokenNegativeCache.MustRegister() + pm.EndpointDiscoveryCache.MustRegister() } // Unregister cancels registration of metrics collector in Prometheus. @@ -151,6 +161,7 @@ func (pm *PrometheusMetrics) Unregister() { prometheus.Unregister(pm.GRPCClientRequestDuration) pm.TokenClaimsCache.Unregister() pm.TokenNegativeCache.Unregister() + pm.EndpointDiscoveryCache.Unregister() } func (pm *PrometheusMetrics) ObserveHTTPClientRequest( diff --git a/jwks/client.go b/jwks/client.go index 06a5506..28973c3 100644 --- a/jwks/client.go +++ b/jwks/client.go @@ -23,8 +23,6 @@ import ( "github.com/acronis/go-authkit/internal/metrics" ) -const OpenIDConfigurationPath = "/.well-known/openid-configuration" - type jwksData struct { Keys []*gojwk.Key `json:"keys"` } @@ -69,7 +67,7 @@ func NewClientWithOpts(opts ClientOpts) *Client { func (c *Client) getRSAPubKeysForIssuer(ctx context.Context, issuerURL string) (map[string]interface{}, error) { logger := idputil.GetLoggerFromProvider(ctx, c.loggerProvider) - openIDConfigURL := strings.TrimPrefix(issuerURL, "/") + OpenIDConfigurationPath + openIDConfigURL := strings.TrimPrefix(issuerURL, "/") + idputil.OpenIDConfigurationPath openIDConfig, err := idputil.GetOpenIDConfiguration( ctx, c.httpClient, openIDConfigURL, nil, logger, c.promMetrics) if err != nil { diff --git a/jwks/client_test.go b/jwks/client_test.go index 5bcca45..bdd4974 100644 --- a/jwks/client_test.go +++ b/jwks/client_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/require" "github.com/acronis/go-authkit/idptest" + "github.com/acronis/go-authkit/internal/idputil" "github.com/acronis/go-authkit/jwks" ) @@ -47,7 +48,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { require.Error(t, err) var openIDCfgErr *jwks.GetOpenIDConfigurationError require.True(t, errors.As(err, &openIDCfgErr)) - require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + require.Equal(t, issuerConfigServer.URL+idputil.OpenIDConfigurationPath, openIDCfgErr.URL) requireLocalhostConnRefusedError(t, openIDCfgErr.Inner) require.Nil(t, pubKey) }) @@ -63,7 +64,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { require.Error(t, err) var openIDCfgErr *jwks.GetOpenIDConfigurationError require.True(t, errors.As(err, &openIDCfgErr)) - require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + require.Equal(t, issuerConfigServer.URL+idputil.OpenIDConfigurationPath, openIDCfgErr.URL) require.EqualError(t, openIDCfgErr.Inner, "unexpected HTTP code 500") require.Nil(t, pubKey) }) @@ -80,7 +81,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { require.Error(t, err) var openIDCfgErr *jwks.GetOpenIDConfigurationError require.True(t, errors.As(err, &openIDCfgErr)) - require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + require.Equal(t, issuerConfigServer.URL+idputil.OpenIDConfigurationPath, openIDCfgErr.URL) var jsonSyntaxErr *json.SyntaxError require.True(t, errors.As(openIDCfgErr, &jsonSyntaxErr)) require.Nil(t, pubKey) @@ -98,7 +99,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { var jwksErr *jwks.GetJWKSError require.True(t, errors.As(err, &jwksErr)) require.Equal(t, jwksServer.URL, jwksErr.URL) - require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, jwksErr.OpenIDConfigurationURL) + require.Equal(t, issuerConfigServer.URL+idputil.OpenIDConfigurationPath, jwksErr.OpenIDConfigurationURL) requireLocalhostConnRefusedError(t, jwksErr.Inner) require.Nil(t, pubKey) }) @@ -117,7 +118,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { var jwksErr *jwks.GetJWKSError require.True(t, errors.As(err, &jwksErr)) require.Equal(t, jwksServer.URL, jwksErr.URL) - require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, jwksErr.OpenIDConfigurationURL) + require.Equal(t, issuerConfigServer.URL+idputil.OpenIDConfigurationPath, jwksErr.OpenIDConfigurationURL) require.EqualError(t, jwksErr.Inner, "unexpected HTTP code 500") require.Nil(t, pubKey) }) @@ -153,7 +154,7 @@ func TestClient_GetRSAPublicKey(t *testing.T) { require.Error(t, err) var openIDCfgErr *jwks.GetOpenIDConfigurationError require.True(t, errors.As(err, &openIDCfgErr)) - require.Equal(t, issuerConfigServer.URL+jwks.OpenIDConfigurationPath, openIDCfgErr.URL) + require.Equal(t, issuerConfigServer.URL+idputil.OpenIDConfigurationPath, openIDCfgErr.URL) require.ErrorIs(t, openIDCfgErr, context.Canceled) require.Nil(t, pubKey) })