From d3ea83d5c6bf7c1d7f116b59617f518f42997d55 Mon Sep 17 00:00:00 2001 From: p53 Date: Fri, 1 Mar 2024 22:59:30 +0100 Subject: [PATCH] Refactor oauthcallback handler --- pkg/keycloak/proxy/handlers.go | 114 ++++++++++++++++++--------------- pkg/keycloak/proxy/server.go | 14 +++- 2 files changed, 76 insertions(+), 52 deletions(-) diff --git a/pkg/keycloak/proxy/handlers.go b/pkg/keycloak/proxy/handlers.go index f4526c3c..37cc1c4d 100644 --- a/pkg/keycloak/proxy/handlers.go +++ b/pkg/keycloak/proxy/handlers.go @@ -106,72 +106,84 @@ func getRedirectionURL( } // oauthAuthorizationHandler is responsible for performing the redirection to oauth provider -func (r *OauthProxy) oauthAuthorizationHandler(wrt http.ResponseWriter, req *http.Request) { - if r.Config.SkipTokenVerification { - wrt.WriteHeader(http.StatusNotAcceptable) - return - } +func oauthAuthorizationHandler( + logger *zap.Logger, + skipTokenVerification bool, + scopes []string, + enablePKCE bool, + signInPage string, + cookManager *cookie.Manager, + newOAuth2Config func(redirectionURL string) *oauth2.Config, + getRedirectionURL func(wrt http.ResponseWriter, req *http.Request) string, + customSignInPage func(wrt http.ResponseWriter, authURL string), +) func(wrt http.ResponseWriter, req *http.Request) { + return func(wrt http.ResponseWriter, req *http.Request) { + if skipTokenVerification { + wrt.WriteHeader(http.StatusNotAcceptable) + return + } - scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope) - if !assertOk { - r.Log.Error(apperrors.ErrAssertionFailed.Error()) - return - } + scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope) + if !assertOk { + logger.Error(apperrors.ErrAssertionFailed.Error()) + return + } - scope.Logger.Debug("authorization handler") + scope.Logger.Debug("authorization handler") - conf := r.newOAuth2Config(r.getRedirectionURL(wrt, req)) - // step: set the access type of the session - accessType := oauth2.AccessTypeOnline + conf := newOAuth2Config(getRedirectionURL(wrt, req)) + // step: set the access type of the session + accessType := oauth2.AccessTypeOnline - if utils.ContainedIn("offline", r.Config.Scopes) { - accessType = oauth2.AccessTypeOffline - } + if utils.ContainedIn("offline", scopes) { + accessType = oauth2.AccessTypeOffline + } - authCodeOptions := []oauth2.AuthCodeOption{ - accessType, - } + authCodeOptions := []oauth2.AuthCodeOption{ + accessType, + } - if r.Config.EnablePKCE { - codeVerifier, err := pkce.NewCodeVerifierWithLength(96) - if err != nil { - r.Log.Error( - apperrors.ErrPKCECodeCreation.Error(), + if enablePKCE { + codeVerifier, err := pkce.NewCodeVerifierWithLength(96) + if err != nil { + logger.Error( + apperrors.ErrPKCECodeCreation.Error(), + ) + return + } + + codeChallenge := pkce.CodeChallengeS256(codeVerifier) + authCodeOptions = append( + authCodeOptions, + oauth2.SetAuthURLParam(pkce.ParamCodeChallenge, codeChallenge), + oauth2.SetAuthURLParam(pkce.ParamCodeChallengeMethod, pkce.MethodS256), ) - return + cookManager.DropPKCECookie(wrt, codeVerifier) } - codeChallenge := pkce.CodeChallengeS256(codeVerifier) - authCodeOptions = append( - authCodeOptions, - oauth2.SetAuthURLParam(pkce.ParamCodeChallenge, codeChallenge), - oauth2.SetAuthURLParam(pkce.ParamCodeChallengeMethod, pkce.MethodS256), + authURL := conf.AuthCodeURL( + req.URL.Query().Get("state"), + authCodeOptions..., ) - r.Cm.DropPKCECookie(wrt, codeVerifier) - } - authURL := conf.AuthCodeURL( - req.URL.Query().Get("state"), - authCodeOptions..., - ) + clientIP := utils.RealIP(req) - clientIP := utils.RealIP(req) + scope.Logger.Debug( + "incoming authorization request from client address", + zap.Any("access_type", accessType), + zap.String("client_ip", clientIP), + zap.String("remote_addr", req.RemoteAddr), + ) - scope.Logger.Debug( - "incoming authorization request from client address", - zap.Any("access_type", accessType), - zap.String("client_ip", clientIP), - zap.String("remote_addr", req.RemoteAddr), - ) + // step: if we have a custom sign in page, lets display that + if signInPage != "" { + customSignInPage(wrt, signInPage) + return + } - // step: if we have a custom sign in page, lets display that - if r.Config.SignInPage != "" { - r.customSignInPage(wrt, r.Config.SignInPage) - return + scope.Logger.Debug("redirecting to auth_url", zap.String("auth_url", authURL)) + redirectToURL(scope.Logger, authURL, wrt, req, http.StatusSeeOther) } - - scope.Logger.Debug("redirecting to auth_url", zap.String("auth_url", authURL)) - redirectToURL(scope.Logger, authURL, wrt, req, http.StatusSeeOther) } /* diff --git a/pkg/keycloak/proxy/server.go b/pkg/keycloak/proxy/server.go index 9ada4e8c..2c14dcad 100644 --- a/pkg/keycloak/proxy/server.go +++ b/pkg/keycloak/proxy/server.go @@ -492,10 +492,22 @@ func (r *OauthProxy) CreateReverseProxy() error { r.accessError, ) + oauthAuthorizationHand := oauthAuthorizationHandler( + r.Log, + r.Config.SkipTokenVerification, + r.Config.Scopes, + r.Config.EnablePKCE, + r.Config.SignInPage, + r.Cm, + r.newOAuth2Config, + r.getRedirectionURL, + r.customSignInPage, + ) + // step: add the routing for oauth engine.With(proxyDenyMiddleware(r.Log)).Route(r.Config.BaseURI+r.Config.OAuthURI, func(eng chi.Router) { eng.MethodNotAllowed(handlers.MethodNotAllowHandlder) - eng.HandleFunc(constant.AuthorizationURL, r.oauthAuthorizationHandler) + eng.HandleFunc(constant.AuthorizationURL, oauthAuthorizationHand) eng.Get(constant.CallbackURL, oauthCallbackHand) eng.Get(constant.ExpiredURL, expirationHandler(r.GetIdentity, r.Config.CookieAccessName)) eng.With(authMid).Get(constant.LogoutURL, logoutHand)