From 138392de31b004f3b4e93b8f539375f3b4fe51d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ruggero=20=28=D0=A0=D1=83=D0=B4=D0=B6=D0=B5=D1=80=D0=BE=29?= =?UTF-8?q?=20Ferretti?= Date: Thu, 19 Dec 2024 18:00:12 +0100 Subject: [PATCH] In jwtAuthHandler.ServeHTTP(), no longer saving the request context off, to allow injecting key/value pairs by user defined functions, such as verifyAccess. --- middleware.go | 36 ++++++------ middleware_test.go | 136 +++++++++++++++++++++++++++++---------------- 2 files changed, 106 insertions(+), 66 deletions(-) diff --git a/middleware.go b/middleware.go index 8dc2912..df444a2 100644 --- a/middleware.go +++ b/middleware.go @@ -139,33 +139,34 @@ func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthM } func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - reqCtx := r.Context() + logger := idputil.GetLoggerFromProvider(r.Context(), h.loggerProvider) bearerToken := GetBearerTokenFromRequest(r) if bearerToken == "" { apiErr := restapi.NewError(h.errorDomain, ErrCodeBearerTokenMissing, ErrMessageBearerTokenMissing) - restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx)) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) return } + // Add the bearer token to the request context + r = r.WithContext(NewContextWithBearerToken(r.Context(), bearerToken)) var jwtClaims jwt.Claims if h.tokenIntrospector != nil { - if introspectionResult, err := h.tokenIntrospector.IntrospectToken(reqCtx, bearerToken); err != nil { + if introspectionResult, err := h.tokenIntrospector.IntrospectToken(r.Context(), bearerToken); err != nil { switch { case errors.Is(err, idptoken.ErrTokenIntrospectionNotNeeded): // Do nothing. Access Token already contains all necessary information for authN/authZ. - h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { + logger.AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { logFunc("token's introspection is not needed") }) h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotNeeded) case errors.Is(err, idptoken.ErrTokenNotIntrospectable): // Token is not introspectable by some reason. // In this case, we will parse it as JWT and use it for authZ. - h.logger(reqCtx).Warn("token is not introspectable, it will be used for authentication and authorization as is", + logger.Warn("token is not introspectable, it will be used for authentication and authorization as is", log.Error(err)) h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotIntrospectable) default: - logger := h.logger(reqCtx) logger.Error("token's introspection failed", log.Error(err)) h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusError) apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) @@ -174,14 +175,14 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } } else { if !introspectionResult.IsActive() { - h.logger(reqCtx).Warn("token was successfully introspected, but it is not active") + logger.Warn("token was successfully introspected, but it is not active") h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotActive) apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) - restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx)) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) return } jwtClaims = introspectionResult.GetClaims() - h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { + logger.AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { logFunc("token was successfully introspected") }) h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusActive) @@ -190,30 +191,27 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { if jwtClaims == nil { var err error - if jwtClaims, err = h.jwtParser.Parse(reqCtx, bearerToken); err != nil { - logger := h.logger(reqCtx) + if jwtClaims, err = h.jwtParser.Parse(r.Context(), bearerToken); err != nil { logger.Error("authentication failed", log.Error(err)) apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) return } } + // Add the JWT claims to the request context + r = r.WithContext(NewContextWithJWTClaims(r.Context(), jwtClaims)) if h.verifyAccess != nil { + // By passing a *http.Request to verifyAccess, we allow its implementations + // to inject new key/value pairs into the request context. if !h.verifyAccess(r, jwtClaims) { apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthorizationFailed, ErrMessageAuthorizationFailed) - restapi.RespondError(rw, http.StatusForbidden, apiErr, h.logger(reqCtx)) + restapi.RespondError(rw, http.StatusForbidden, apiErr, logger) return } } - reqCtx = NewContextWithBearerToken(reqCtx, bearerToken) - reqCtx = NewContextWithJWTClaims(reqCtx, jwtClaims) - h.next.ServeHTTP(rw, r.WithContext(reqCtx)) -} - -func (h *jwtAuthHandler) logger(ctx context.Context) log.FieldLogger { - return idputil.GetLoggerFromProvider(ctx, h.loggerProvider) + h.next.ServeHTTP(rw, r) } // GetBearerTokenFromRequest extracts jwt token from request headers. diff --git a/middleware_test.go b/middleware_test.go index 66d35f2..2a27ea7 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -4,7 +4,7 @@ Copyright © 2024 Acronis International GmbH. Released under MIT license. */ -package authkit +package authkit_test import ( "context" @@ -17,19 +17,25 @@ import ( jwtgo "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" + "github.com/acronis/go-authkit" "github.com/acronis/go-authkit/idptoken" "github.com/acronis/go-authkit/internal/metrics" "github.com/acronis/go-authkit/jwt" ) +const ( + testErrDomain = "TestDomain" + testBearerToken = "a.b.c" +) + type mockJWTAuthMiddlewareNextHandler struct { - called int - jwtClaims jwt.Claims + request *http.Request + called int } func (h *mockJWTAuthMiddlewareNextHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { + h.request = r h.called++ - h.jwtClaims = GetJWTClaimsFromContext(r.Context()) } type mockJWTParser struct { @@ -59,24 +65,22 @@ func (i *mockTokenIntrospector) IntrospectToken(_ context.Context, token string) } func TestJWTAuthMiddleware(t *testing.T) { - const errDomain = "TestDomain" - t.Run("bearer token is missing", func(t *testing.T) { for _, headerVal := range []string{"", "foobar", "Bearer", "Bearer "} { parser := &mockJWTParser{} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) if headerVal != "" { - req.Header.Set(HeaderAuthorization, headerVal) + req.Header.Set(authkit.HeaderAuthorization, headerVal) } resp := httptest.NewRecorder() - JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(testErrDomain, parser)(next).ServeHTTP(resp, req) - testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeBearerTokenMissing) + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, testErrDomain, authkit.ErrCodeBearerTokenMissing) require.Equal(t, 0, parser.parseCalled) require.Equal(t, 0, next.called) - require.Nil(t, next.jwtClaims) + require.Nil(t, next.request) } }) @@ -84,15 +88,15 @@ func TestJWTAuthMiddleware(t *testing.T) { parser := &mockJWTParser{errToReturn: errors.New("malformed JWT")} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer foobar") + withBearerToken(req, "foobar") resp := httptest.NewRecorder() - JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(testErrDomain, parser)(next).ServeHTTP(resp, req) - testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, testErrDomain, authkit.ErrCodeAuthenticationFailed) require.Equal(t, 1, parser.parseCalled) require.Equal(t, 0, next.called) - require.Nil(t, next.jwtClaims) + require.Nil(t, next.request) }) t.Run("ok", func(t *testing.T) { @@ -100,16 +104,16 @@ func TestJWTAuthMiddleware(t *testing.T) { parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() - JWTAuthMiddleware(errDomain, parser)(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(testErrDomain, parser)(next).ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) require.Equal(t, 1, parser.parseCalled) require.Equal(t, 1, next.called) - require.NotNil(t, next.jwtClaims) - nextIssuer, err := next.jwtClaims.GetIssuer() + require.NotNil(t, authkit.GetJWTClaimsFromContext(next.request.Context())) + nextIssuer, err := authkit.GetJWTClaimsFromContext(next.request.Context()).GetIssuer() require.NoError(t, err) require.Equal(t, issuer, nextIssuer) }) @@ -120,15 +124,16 @@ func TestJWTAuthMiddleware(t *testing.T) { introspector := &mockTokenIntrospector{errToReturn: errors.New("introspection failed")} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusError), 0) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(testErrDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next). + ServeHTTP(resp, req) - testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, testErrDomain, authkit.ErrCodeAuthenticationFailed) require.Equal(t, 1, introspector.introspectCalled) require.Equal(t, 0, parser.parseCalled) require.Equal(t, 0, next.called) @@ -143,19 +148,20 @@ func TestJWTAuthMiddleware(t *testing.T) { introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenIntrospectionNotNeeded} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotNeeded), 0) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(testErrDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next). + ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) require.Equal(t, 1, introspector.introspectCalled) require.Equal(t, 1, parser.parseCalled) require.Equal(t, 1, next.called) - nextIssuer, err := next.jwtClaims.GetIssuer() + nextIssuer, err := authkit.GetJWTClaimsFromContext(next.request.Context()).GetIssuer() require.NoError(t, err) require.Equal(t, issuer, nextIssuer) @@ -169,20 +175,21 @@ func TestJWTAuthMiddleware(t *testing.T) { introspector := &mockTokenIntrospector{errToReturn: idptoken.ErrTokenNotIntrospectable} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotIntrospectable), 0) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(testErrDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next). + ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) require.Equal(t, 1, introspector.introspectCalled) require.Equal(t, 1, parser.parseCalled) require.Equal(t, 1, next.called) - require.NotNil(t, next.jwtClaims) - nextIssuer, err := next.jwtClaims.GetIssuer() + require.NotNil(t, authkit.GetJWTClaimsFromContext(next.request.Context())) + nextIssuer, err := authkit.GetJWTClaimsFromContext(next.request.Context()).GetIssuer() require.NoError(t, err) require.Equal(t, issuer, nextIssuer) @@ -195,15 +202,16 @@ func TestJWTAuthMiddleware(t *testing.T) { introspector := &mockTokenIntrospector{resultToReturn: &idptoken.DefaultIntrospectionResult{Active: false}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotActive), 0) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(testErrDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next). + ServeHTTP(resp, req) - testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) + testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, testErrDomain, authkit.ErrCodeAuthenticationFailed) require.Equal(t, 1, introspector.introspectCalled) require.Equal(t, 0, parser.parseCalled) require.Equal(t, 0, next.called) @@ -219,45 +227,74 @@ func TestJWTAuthMiddleware(t *testing.T) { Active: true, DefaultClaims: jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 0) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) + authkit.JWTAuthMiddleware(testErrDomain, parser, authkit.WithJWTAuthMiddlewareTokenIntrospector(introspector))(next). + ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) require.Equal(t, 1, introspector.introspectCalled) require.Equal(t, 0, parser.parseCalled) require.Equal(t, 1, next.called) - require.NotNil(t, next.jwtClaims) - nextIssuer, err := next.jwtClaims.GetIssuer() + require.NotNil(t, authkit.GetJWTClaimsFromContext(next.request.Context())) + nextIssuer, err := authkit.GetJWTClaimsFromContext(next.request.Context()).GetIssuer() require.NoError(t, err) require.Equal(t, issuer, nextIssuer) testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 1) }) + + t.Run("context keys added by verifyAccess are preserved", func(t *testing.T) { + const issuer = "my-idp.com" + parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}}} + next := &mockJWTAuthMiddlewareNextHandler{} + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + withBearerToken(req, testBearerToken) + resp := httptest.NewRecorder() + + const ( + ctxKey = "verify-access-key" + ctxValue = "verify-access-value" + ) + var verifyAccess = func(r *http.Request, claims jwt.Claims) bool { + *r = *r.WithContext(context.WithValue(r.Context(), ctxKey, ctxValue)) + return true + } + + authkit.JWTAuthMiddleware(testErrDomain, parser, authkit.WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next). + ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, 1, parser.parseCalled) + require.Equal(t, 1, next.called) + require.Equal(t, testBearerToken, authkit.GetBearerTokenFromContext(next.request.Context()), + "context is missing bearer token") + require.Equal(t, ctxValue, next.request.Context().Value(ctxKey), + "context key added by verifyAccess is not preserved") + }) } func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) { - const errDomain = "TestDomain" - t.Run("authorization failed", func(t *testing.T) { parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() - verifyAccess := NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"}) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).ServeHTTP(resp, req) + verifyAccess := authkit.NewVerifyAccessByRolesInJWT(authkit.Role{Namespace: "my-service", Name: "admin"}) + authkit.JWTAuthMiddleware(testErrDomain, parser, authkit.WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next). + ServeHTTP(resp, req) - testutil.RequireErrorInRecorder(t, resp, http.StatusForbidden, errDomain, ErrCodeAuthorizationFailed) + testutil.RequireErrorInRecorder(t, resp, http.StatusForbidden, testErrDomain, authkit.ErrCodeAuthorizationFailed) require.Equal(t, 1, parser.parseCalled) require.Equal(t, 0, next.called) - require.Nil(t, next.jwtClaims) + require.Nil(t, next.request) }) t.Run("ok", func(t *testing.T) { @@ -265,16 +302,21 @@ func TestJWTAuthMiddlewareWithVerifyAccess(t *testing.T) { parser := &mockJWTParser{claimsToReturn: &jwt.DefaultClaims{Scope: scope}} next := &mockJWTAuthMiddlewareNextHandler{} req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - req.Header.Set(HeaderAuthorization, "Bearer a.b.c") + withBearerToken(req, testBearerToken) resp := httptest.NewRecorder() - verifyAccess := NewVerifyAccessByRolesInJWT(Role{Namespace: "my-service", Name: "admin"}) - JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next).ServeHTTP(resp, req) + verifyAccess := authkit.NewVerifyAccessByRolesInJWT(authkit.Role{Namespace: "my-service", Name: "admin"}) + authkit.JWTAuthMiddleware(testErrDomain, parser, authkit.WithJWTAuthMiddlewareVerifyAccess(verifyAccess))(next). + ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) require.Equal(t, 1, parser.parseCalled) require.Equal(t, 1, next.called) - require.NotNil(t, next.jwtClaims) - require.EqualValues(t, scope, next.jwtClaims.GetScope()) + require.NotNil(t, authkit.GetJWTClaimsFromContext(next.request.Context())) + require.EqualValues(t, scope, authkit.GetJWTClaimsFromContext(next.request.Context()).GetScope()) }) } + +func withBearerToken(r *http.Request, t string) { + r.Header.Set(authkit.HeaderAuthorization, "Bearer "+t) +}