Skip to content

Commit

Permalink
In jwtAuthHandler.ServeHTTP(), no longer saving the request context o…
Browse files Browse the repository at this point in the history
…ff, to allow injecting key/value pairs by user defined functions, such as verifyAccess.
  • Loading branch information
Ruggero (Руджеро) Ferretti authored and vasayxtx committed Dec 20, 2024
1 parent 9b0feff commit 138392d
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 66 deletions.
36 changes: 17 additions & 19 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,33 +139,34 @@ func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthM
}

func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
reqCtx := r.Context()
logger := idputil.GetLoggerFromProvider(r.Context(), h.loggerProvider)

bearerToken := GetBearerTokenFromRequest(r)
if bearerToken == "" {
apiErr := restapi.NewError(h.errorDomain, ErrCodeBearerTokenMissing, ErrMessageBearerTokenMissing)
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx))
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger)
return
}
// Add the bearer token to the request context
r = r.WithContext(NewContextWithBearerToken(r.Context(), bearerToken))

var jwtClaims jwt.Claims
if h.tokenIntrospector != nil {
if introspectionResult, err := h.tokenIntrospector.IntrospectToken(reqCtx, bearerToken); err != nil {
if introspectionResult, err := h.tokenIntrospector.IntrospectToken(r.Context(), bearerToken); err != nil {
switch {
case errors.Is(err, idptoken.ErrTokenIntrospectionNotNeeded):
// Do nothing. Access Token already contains all necessary information for authN/authZ.
h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logger.AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logFunc("token's introspection is not needed")
})
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotNeeded)
case errors.Is(err, idptoken.ErrTokenNotIntrospectable):
// Token is not introspectable by some reason.
// In this case, we will parse it as JWT and use it for authZ.
h.logger(reqCtx).Warn("token is not introspectable, it will be used for authentication and authorization as is",
logger.Warn("token is not introspectable, it will be used for authentication and authorization as is",
log.Error(err))
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotIntrospectable)
default:
logger := h.logger(reqCtx)
logger.Error("token's introspection failed", log.Error(err))
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusError)
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
Expand All @@ -174,14 +175,14 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}
} else {
if !introspectionResult.IsActive() {
h.logger(reqCtx).Warn("token was successfully introspected, but it is not active")
logger.Warn("token was successfully introspected, but it is not active")
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotActive)
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx))
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger)
return
}
jwtClaims = introspectionResult.GetClaims()
h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logger.AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logFunc("token was successfully introspected")
})
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusActive)
Expand All @@ -190,30 +191,27 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {

if jwtClaims == nil {
var err error
if jwtClaims, err = h.jwtParser.Parse(reqCtx, bearerToken); err != nil {
logger := h.logger(reqCtx)
if jwtClaims, err = h.jwtParser.Parse(r.Context(), bearerToken); err != nil {
logger.Error("authentication failed", log.Error(err))
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger)
return
}
}
// Add the JWT claims to the request context
r = r.WithContext(NewContextWithJWTClaims(r.Context(), jwtClaims))

if h.verifyAccess != nil {
// By passing a *http.Request to verifyAccess, we allow its implementations
// to inject new key/value pairs into the request context.
if !h.verifyAccess(r, jwtClaims) {
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthorizationFailed, ErrMessageAuthorizationFailed)
restapi.RespondError(rw, http.StatusForbidden, apiErr, h.logger(reqCtx))
restapi.RespondError(rw, http.StatusForbidden, apiErr, logger)
return
}
}

reqCtx = NewContextWithBearerToken(reqCtx, bearerToken)
reqCtx = NewContextWithJWTClaims(reqCtx, jwtClaims)
h.next.ServeHTTP(rw, r.WithContext(reqCtx))
}

func (h *jwtAuthHandler) logger(ctx context.Context) log.FieldLogger {
return idputil.GetLoggerFromProvider(ctx, h.loggerProvider)
h.next.ServeHTTP(rw, r)
}

// GetBearerTokenFromRequest extracts jwt token from request headers.
Expand Down
Loading

0 comments on commit 138392d

Please sign in to comment.