Skip to content

Commit

Permalink
Refactor oauthcallback handler
Browse files Browse the repository at this point in the history
  • Loading branch information
p53 authored Mar 1, 2024
1 parent 3624c43 commit d3ea83d
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 52 deletions.
114 changes: 63 additions & 51 deletions pkg/keycloak/proxy/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,72 +106,84 @@ func getRedirectionURL(
}

// oauthAuthorizationHandler is responsible for performing the redirection to oauth provider
func (r *OauthProxy) oauthAuthorizationHandler(wrt http.ResponseWriter, req *http.Request) {
if r.Config.SkipTokenVerification {
wrt.WriteHeader(http.StatusNotAcceptable)
return
}
func oauthAuthorizationHandler(
logger *zap.Logger,
skipTokenVerification bool,
scopes []string,
enablePKCE bool,
signInPage string,
cookManager *cookie.Manager,
newOAuth2Config func(redirectionURL string) *oauth2.Config,
getRedirectionURL func(wrt http.ResponseWriter, req *http.Request) string,
customSignInPage func(wrt http.ResponseWriter, authURL string),
) func(wrt http.ResponseWriter, req *http.Request) {
return func(wrt http.ResponseWriter, req *http.Request) {
if skipTokenVerification {
wrt.WriteHeader(http.StatusNotAcceptable)
return
}

scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
if !assertOk {
r.Log.Error(apperrors.ErrAssertionFailed.Error())
return
}
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
}

scope.Logger.Debug("authorization handler")
scope.Logger.Debug("authorization handler")

conf := r.newOAuth2Config(r.getRedirectionURL(wrt, req))
// step: set the access type of the session
accessType := oauth2.AccessTypeOnline
conf := newOAuth2Config(getRedirectionURL(wrt, req))
// step: set the access type of the session
accessType := oauth2.AccessTypeOnline

if utils.ContainedIn("offline", r.Config.Scopes) {
accessType = oauth2.AccessTypeOffline
}
if utils.ContainedIn("offline", scopes) {
accessType = oauth2.AccessTypeOffline
}

authCodeOptions := []oauth2.AuthCodeOption{
accessType,
}
authCodeOptions := []oauth2.AuthCodeOption{
accessType,
}

if r.Config.EnablePKCE {
codeVerifier, err := pkce.NewCodeVerifierWithLength(96)
if err != nil {
r.Log.Error(
apperrors.ErrPKCECodeCreation.Error(),
if enablePKCE {
codeVerifier, err := pkce.NewCodeVerifierWithLength(96)
if err != nil {
logger.Error(
apperrors.ErrPKCECodeCreation.Error(),
)
return
}

codeChallenge := pkce.CodeChallengeS256(codeVerifier)
authCodeOptions = append(
authCodeOptions,
oauth2.SetAuthURLParam(pkce.ParamCodeChallenge, codeChallenge),
oauth2.SetAuthURLParam(pkce.ParamCodeChallengeMethod, pkce.MethodS256),
)
return
cookManager.DropPKCECookie(wrt, codeVerifier)
}

codeChallenge := pkce.CodeChallengeS256(codeVerifier)
authCodeOptions = append(
authCodeOptions,
oauth2.SetAuthURLParam(pkce.ParamCodeChallenge, codeChallenge),
oauth2.SetAuthURLParam(pkce.ParamCodeChallengeMethod, pkce.MethodS256),
authURL := conf.AuthCodeURL(
req.URL.Query().Get("state"),
authCodeOptions...,
)
r.Cm.DropPKCECookie(wrt, codeVerifier)
}

authURL := conf.AuthCodeURL(
req.URL.Query().Get("state"),
authCodeOptions...,
)
clientIP := utils.RealIP(req)

clientIP := utils.RealIP(req)
scope.Logger.Debug(
"incoming authorization request from client address",
zap.Any("access_type", accessType),
zap.String("client_ip", clientIP),
zap.String("remote_addr", req.RemoteAddr),
)

scope.Logger.Debug(
"incoming authorization request from client address",
zap.Any("access_type", accessType),
zap.String("client_ip", clientIP),
zap.String("remote_addr", req.RemoteAddr),
)
// step: if we have a custom sign in page, lets display that
if signInPage != "" {
customSignInPage(wrt, signInPage)
return
}

// step: if we have a custom sign in page, lets display that
if r.Config.SignInPage != "" {
r.customSignInPage(wrt, r.Config.SignInPage)
return
scope.Logger.Debug("redirecting to auth_url", zap.String("auth_url", authURL))
redirectToURL(scope.Logger, authURL, wrt, req, http.StatusSeeOther)
}

scope.Logger.Debug("redirecting to auth_url", zap.String("auth_url", authURL))
redirectToURL(scope.Logger, authURL, wrt, req, http.StatusSeeOther)
}

/*
Expand Down
14 changes: 13 additions & 1 deletion pkg/keycloak/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,22 @@ func (r *OauthProxy) CreateReverseProxy() error {
r.accessError,
)

oauthAuthorizationHand := oauthAuthorizationHandler(
r.Log,
r.Config.SkipTokenVerification,
r.Config.Scopes,
r.Config.EnablePKCE,
r.Config.SignInPage,
r.Cm,
r.newOAuth2Config,
r.getRedirectionURL,
r.customSignInPage,
)

// step: add the routing for oauth
engine.With(proxyDenyMiddleware(r.Log)).Route(r.Config.BaseURI+r.Config.OAuthURI, func(eng chi.Router) {
eng.MethodNotAllowed(handlers.MethodNotAllowHandlder)
eng.HandleFunc(constant.AuthorizationURL, r.oauthAuthorizationHandler)
eng.HandleFunc(constant.AuthorizationURL, oauthAuthorizationHand)
eng.Get(constant.CallbackURL, oauthCallbackHand)
eng.Get(constant.ExpiredURL, expirationHandler(r.GetIdentity, r.Config.CookieAccessName))
eng.With(authMid).Get(constant.LogoutURL, logoutHand)
Expand Down

0 comments on commit d3ea83d

Please sign in to comment.