From 5195727f7ea0c67e3a1499e878e3d55da8dcf8cb Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Mon, 2 Sep 2024 16:07:25 +0200 Subject: [PATCH] fix: pass URL params explicitly --- selfservice/strategy/code/strategy_login.go | 4 ++-- .../strategy/code/strategy_recovery.go | 18 ++++++++-------- selfservice/strategy/oidc/strategy.go | 21 +++++++++---------- .../strategy/oidc/strategy_registration.go | 1 + 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/selfservice/strategy/code/strategy_login.go b/selfservice/strategy/code/strategy_login.go index cb734b1bf4a1..63a143ef9596 100644 --- a/selfservice/strategy/code/strategy_login.go +++ b/selfservice/strategy/code/strategy_login.go @@ -231,7 +231,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } return nil, nil case flow.StateEmailSent: - i, err := s.loginVerifyCode(ctx, r, f, &p, sess) + i, err := s.loginVerifyCode(ctx, f, &p, sess) if err != nil { return nil, s.HandleLoginError(r, f, &p, err) } @@ -437,7 +437,7 @@ func maybeNormalizeEmail(input string) string { return input } -func (s *Strategy) loginVerifyCode(ctx context.Context, r *http.Request, f *login.Flow, p *updateLoginFlowWithCodeMethod, sess *session.Session) (_ *identity.Identity, err error) { +func (s *Strategy) loginVerifyCode(ctx context.Context, f *login.Flow, p *updateLoginFlowWithCodeMethod, sess *session.Session) (_ *identity.Identity, err error) { ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.loginVerifyCode") defer otelx.End(span, &err) diff --git a/selfservice/strategy/code/strategy_recovery.go b/selfservice/strategy/code/strategy_recovery.go index 8ab34c4f1ce2..d4401beb4478 100644 --- a/selfservice/strategy/code/strategy_recovery.go +++ b/selfservice/strategy/code/strategy_recovery.go @@ -206,7 +206,7 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request, return s.retryRecoveryFlow(w, r, f.Type, RetryWithError(err)) } case flow.TypeAPI: - if err := s.deps.SessionPersister().UpsertSession(r.Context(), sess); err != nil { + if err := s.deps.SessionPersister().UpsertSession(ctx, sess); err != nil { return s.retryRecoveryFlow(w, r, f.Type, RetryWithError(err)) } f.ContinueWith = append(f.ContinueWith, flow.NewContinueWithSetToken(sess.Token)) @@ -217,7 +217,7 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request, return s.retryRecoveryFlow(w, r, f.Type, RetryWithError(err)) } - returnToURL := s.deps.Config().SelfServiceFlowRecoveryReturnTo(r.Context(), nil) + returnToURL := s.deps.Config().SelfServiceFlowRecoveryReturnTo(ctx, nil) returnTo := "" if returnToURL != nil { returnTo = returnToURL.String() @@ -230,12 +230,12 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request, config := s.deps.Config() sf.UI.Messages.Set(text.NewRecoverySuccessful(time.Now().Add(config.SelfServiceFlowSettingsPrivilegedSessionMaxAge(ctx)))) - if err := s.deps.SettingsFlowPersister().UpdateSettingsFlow(r.Context(), sf); err != nil { + if err := s.deps.SettingsFlowPersister().UpdateSettingsFlow(ctx, sf); err != nil { return s.retryRecoveryFlow(w, r, f.Type, RetryWithError(err)) } if s.deps.Config().UseContinueWithTransitions(ctx) { - redirectTo := sf.AppendTo(s.deps.Config().SelfServiceFlowSettingsUI(r.Context())).String() + redirectTo := sf.AppendTo(s.deps.Config().SelfServiceFlowSettingsUI(ctx)).String() switch { case f.Type.IsAPI(), x.IsJSONRequest(r): f.ContinueWith = append(f.ContinueWith, flow.NewContinueWithSettingsUI(sf, redirectTo)) @@ -245,9 +245,9 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request, } } else { if x.IsJSONRequest(r) { - s.deps.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(sf.AppendTo(s.deps.Config().SelfServiceFlowSettingsUI(r.Context())).String())) + s.deps.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(sf.AppendTo(s.deps.Config().SelfServiceFlowSettingsUI(ctx)).String())) } else { - http.Redirect(w, r, sf.AppendTo(s.deps.Config().SelfServiceFlowSettingsUI(r.Context())).String(), http.StatusSeeOther) + http.Redirect(w, r, sf.AppendTo(s.deps.Config().SelfServiceFlowSettingsUI(ctx)).String(), http.StatusSeeOther) } } @@ -265,7 +265,7 @@ func (s *Strategy) recoveryUseCode(w http.ResponseWriter, r *http.Request, body } if f.Type == flow.TypeBrowser && !x.IsJSONRequest(r) { - http.Redirect(w, r, f.AppendTo(s.deps.Config().SelfServiceFlowRecoveryUI(r.Context())).String(), http.StatusSeeOther) + http.Redirect(w, r, f.AppendTo(s.deps.Config().SelfServiceFlowRecoveryUI(ctx)).String(), http.StatusSeeOther) } else { s.deps.Writer().Write(w, r, f) } @@ -395,7 +395,7 @@ func (s *Strategy) recoveryHandleFormSubmission(w http.ResponseWriter, r *http.R // re-initialize the UI with a "clean" new state f.UI = &container.Container{ Method: "POST", - Action: flow.AppendFlowTo(urlx.AppendPaths(s.deps.Config().SelfPublicURL(r.Context()), recovery.RouteSubmitFlow), f.ID).String(), + Action: flow.AppendFlowTo(urlx.AppendPaths(s.deps.Config().SelfPublicURL(ctx), recovery.RouteSubmitFlow), f.ID).String(), } f.UI.SetCSRF(s.deps.GenerateCSRFToken(r)) @@ -420,7 +420,7 @@ func (s *Strategy) recoveryHandleFormSubmission(w http.ResponseWriter, r *http.R f.UI.Nodes.Append(node.NewInputField("email", body.Email, node.CodeGroup, node.InputAttributeTypeSubmit). WithMetaLabel(text.NewInfoNodeResendOTP()), ) - if err := s.deps.RecoveryFlowPersister().UpdateRecoveryFlow(r.Context(), f); err != nil { + if err := s.deps.RecoveryFlowPersister().UpdateRecoveryFlow(ctx, f); err != nil { return s.HandleRecoveryError(w, r, f, body, err) } diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 38483e14fdc4..fc92b2f3a16a 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -246,7 +246,7 @@ func (s *Strategy) validateFlow(ctx context.Context, r *http.Request, rid uuid.U return ar, err // this must return the error } -func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (f contextFlow, providerID string, ac *AuthCodeContainer, err error) { +func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request, urlParams httprouter.Params) (f contextFlow, providerID string, ac *AuthCodeContainer, err error) { var ( codeParam = stringsx.Coalesce(r.URL.Query().Get("code"), r.URL.Query().Get("authCode")) stateParam = r.URL.Query().Get("state") @@ -268,7 +268,6 @@ func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (f c // Determine the provider from the flow context or the URL. providerID = providerFromFlow(f) - urlParams, _ := r.Context().Value(httprouter.ParamsKey).(httprouter.Params) if providerFromURL := urlParams.ByName("provider"); providerFromURL != "" { // We're serving an old-style OIDC callback URL with provider in the URL. if providerID == "" { @@ -372,7 +371,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt defer otelx.End(span, &err) r = r.WithContext(ctx) - f, pid, cntnr, err := s.ValidateCallback(w, r) + f, pid, cntnr, err := s.ValidateCallback(w, r, ps) if err != nil { if f != nil { s.forwardError(w, r, f, s.handleError(ctx, w, r, f, pid, nil, err)) @@ -578,28 +577,28 @@ func (s *Strategy) handleError(ctx context.Context, w http.ResponseWriter, r *ht rf.UI.Messages.Add(text.NewErrorValidationDuplicateCredentialsOnOIDCLink()) } - lf, err := s.registrationToLogin(w, r, rf, usedProviderID) + lf, err := s.registrationToLogin(w, r, rf) if err != nil { return err } // return a new login flow with the error message embedded in the login flow. var redirectURL *url.URL if lf.Type == flow.TypeAPI { - returnTo := s.d.Config().SelfServiceBrowserDefaultReturnTo(r.Context()) + returnTo := s.d.Config().SelfServiceBrowserDefaultReturnTo(ctx) if redirecter, ok := f.(flow.FlowWithRedirect); ok { - secureReturnTo, err := x.SecureRedirectTo(r, returnTo, redirecter.SecureRedirectToOpts(r.Context(), s.d)...) + secureReturnTo, err := x.SecureRedirectTo(r, returnTo, redirecter.SecureRedirectToOpts(ctx, s.d)...) if err == nil { returnTo = secureReturnTo } } redirectURL = lf.AppendTo(returnTo) } else { - redirectURL = lf.AppendTo(s.d.Config().SelfServiceFlowLoginUI(r.Context())) + redirectURL = lf.AppendTo(s.d.Config().SelfServiceFlowLoginUI(ctx)) } if dc, err := flow.DuplicateCredentials(lf); err == nil && dc != nil { redirectURL = urlx.CopyWithQuery(redirectURL, url.Values{"no_org_ui": {"true"}}) - s.populateAccountLinkingUI(r.Context(), lf, usedProviderID, dc.DuplicateIdentifier, dup.AvailableCredentials()) - if err := s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), lf); err != nil { + s.populateAccountLinkingUI(ctx, lf, usedProviderID, dc.DuplicateIdentifier, dup.AvailableCredentials()) + if err := s.d.LoginFlowPersister().UpdateLoginFlow(ctx, lf); err != nil { return err } } @@ -615,12 +614,12 @@ func (s *Strategy) handleError(ctx context.Context, w http.ResponseWriter, r *ht AddProvider(rf.UI, usedProviderID, text.NewInfoRegistrationContinue()) if traits != nil { - ds, err := s.d.Config().DefaultIdentityTraitsSchemaURL(r.Context()) + ds, err := s.d.Config().DefaultIdentityTraitsSchemaURL(ctx) if err != nil { return err } - traitNodes, err := container.NodesFromJSONSchema(r.Context(), node.OpenIDConnectGroup, ds.String(), "", nil) + traitNodes, err := container.NodesFromJSONSchema(ctx, node.OpenIDConnectGroup, ds.String(), "", nil) if err != nil { return err } diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index e786d9fa4211..cfc49ce1e8b4 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -296,6 +296,7 @@ func (s *Strategy) registrationToLogin(w http.ResponseWriter, r *http.Request, r func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWriter, r *http.Request, rf *registration.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider, container *AuthCodeContainer, idToken string) (_ *login.Flow, err error) { ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.strategy.processRegistration") defer otelx.End(span, &err) + r = r.WithContext(ctx) if _, _, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject)); err == nil { // If the identity already exists, we should perform the login flow instead.