diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/selfservice/flow/settings/handler_test.go b/selfservice/flow/settings/handler_test.go index 60920c101b19..9d4b3e670a17 100644 --- a/selfservice/flow/settings/handler_test.go +++ b/selfservice/flow/settings/handler_test.go @@ -580,7 +580,7 @@ func TestHandler(t *testing.T) { require.NoError(t, json.Unmarshal(body, &f)) actual, res := testhelpers.SettingsMakeRequest(t, false, true, &f, primaryUser, fmt.Sprintf(`{"method":"profile", "numby": 15, "csrf_token": "%s"}`, x.FakeCSRFToken)) - assert.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, http.StatusOK, res.StatusCode) require.Len(t, primaryUser.Jar.Cookies(urlx.ParseOrPanic(publicTS.URL+login.RouteGetFlow)), 1) require.Contains(t, fmt.Sprintf("%v", primaryUser.Jar.Cookies(urlx.ParseOrPanic(publicTS.URL))), "ory_kratos_session") assert.Equal(t, "Your changes have been saved!", gjson.Get(actual, "ui.messages.0.text").String(), actual) diff --git a/selfservice/strategy/code/strategy_verification.go b/selfservice/strategy/code/strategy_verification.go index 5b30e4c2f5ec..9ca51a68f525 100644 --- a/selfservice/strategy/code/strategy_verification.go +++ b/selfservice/strategy/code/strategy_verification.go @@ -66,11 +66,15 @@ func (s *Strategy) decodeVerification(r *http.Request) (*updateVerificationFlowW } // handleVerificationError is a convenience function for handling all types of errors that may occur (e.g. validation error). -func (s *Strategy) handleVerificationError(w http.ResponseWriter, r *http.Request, f *verification.Flow, body *updateVerificationFlowWithCodeMethod, err error) error { +func (s *Strategy) handleVerificationError(r *http.Request, f *verification.Flow, body *updateVerificationFlowWithCodeMethod, err error) error { if f != nil { f.UI.SetCSRF(s.deps.GenerateCSRFToken(r)) + email := "" + if body != nil { + email = body.Email + } f.UI.GetNodes().Upsert( - node.NewInputField("email", body.Email, node.CodeGroup, node.InputAttributeTypeEmail, node.WithRequiredInputAttribute).WithMetaLabel(text.NewInfoNodeInputEmail()), + node.NewInputField("email", email, node.CodeGroup, node.InputAttributeTypeEmail, node.WithRequiredInputAttribute).WithMetaLabel(text.NewInfoNodeInputEmail()), ) } @@ -137,17 +141,17 @@ func (s *Strategy) Verify(w http.ResponseWriter, r *http.Request, f *verificatio body, err := s.decodeVerification(r) if err != nil { - return s.handleVerificationError(w, r, nil, body, err) + return s.handleVerificationError(r, nil, body, err) } f.TransientPayload = body.TransientPayload if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), string(body.getMethod()), s.deps); err != nil { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } if err := f.Valid(); err != nil { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } switch f.State { @@ -197,20 +201,20 @@ func (s *Strategy) verificationHandleFormSubmission(w http.ResponseWriter, r *ht return s.verificationUseCode(w, r, body.Code, f) } else if len(body.Email) == 0 { // If no code and no email was provided, fail with a validation error - return s.handleVerificationError(w, r, f, body, schema.NewRequiredError("#/email", "email")) + return s.handleVerificationError(r, f, body, schema.NewRequiredError("#/email", "email")) } if err := flow.EnsureCSRF(s.deps, r, f.Type, s.deps.Config().DisableAPIFlowEnforcement(r.Context()), s.deps.GenerateCSRFToken, body.CSRFToken); err != nil { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } if err := s.deps.VerificationCodePersister().DeleteVerificationCodesOfFlow(r.Context(), f.ID); err != nil { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } if err := s.deps.CodeSender().SendVerificationCode(r.Context(), f, identity.VerifiableAddressTypeEmail, body.Email); err != nil { if !errors.Is(err, ErrUnknownAddress) { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } // Continue execution } @@ -218,7 +222,7 @@ func (s *Strategy) verificationHandleFormSubmission(w http.ResponseWriter, r *ht f.State = flow.StateEmailSent if err := s.PopulateVerificationMethod(r, f); err != nil { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } if body.Email != "" { @@ -229,7 +233,7 @@ func (s *Strategy) verificationHandleFormSubmission(w http.ResponseWriter, r *ht } if err := s.deps.VerificationFlowPersister().UpdateVerificationFlow(r.Context(), f); err != nil { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } return nil @@ -305,13 +309,13 @@ func (s *Strategy) retryVerificationFlowWithMessage(w http.ResponseWriter, r *ht f, err := verification.NewFlow(s.deps.Config(), s.deps.Config().SelfServiceFlowVerificationRequestLifespan(r.Context()), s.deps.CSRFHandler().RegenerateToken(w, r), r, s, ft) if err != nil { - return s.handleVerificationError(w, r, f, nil, err) + return s.handleVerificationError(r, f, nil, err) } f.UI.Messages.Add(message) if err := s.deps.VerificationFlowPersister().CreateVerificationFlow(r.Context(), f); err != nil { - return s.handleVerificationError(w, r, f, nil, err) + return s.handleVerificationError(r, f, nil, err) } if x.IsJSONRequest(r) { @@ -333,7 +337,7 @@ func (s *Strategy) retryVerificationFlowWithError(w http.ResponseWriter, r *http f, err := verification.NewFlow(s.deps.Config(), s.deps.Config().SelfServiceFlowVerificationRequestLifespan(r.Context()), s.deps.CSRFHandler().RegenerateToken(w, r), r, s, ft) if err != nil { - return s.handleVerificationError(w, r, f, nil, err) + return s.handleVerificationError(r, f, nil, err) } var toReturn error @@ -346,7 +350,7 @@ func (s *Strategy) retryVerificationFlowWithError(w http.ResponseWriter, r *http } if err := s.deps.VerificationFlowPersister().CreateVerificationFlow(r.Context(), f); err != nil { - return s.handleVerificationError(w, r, f, nil, err) + return s.handleVerificationError(r, f, nil, err) } if x.IsJSONRequest(r) { diff --git a/selfservice/strategy/idfirst/strategy_login.go b/selfservice/strategy/idfirst/strategy_login.go index d479ae1a5b05..eaca286f1fb7 100644 --- a/selfservice/strategy/idfirst/strategy_login.go +++ b/selfservice/strategy/idfirst/strategy_login.go @@ -27,7 +27,7 @@ var ( ErrNoCredentialsFound = errors.New("no credentials found") ) -func (s *Strategy) handleLoginError(w http.ResponseWriter, r *http.Request, f *login.Flow, payload *updateLoginFlowWithIdentifierFirstMethod, err error) error { +func (s *Strategy) handleLoginError(r *http.Request, f *login.Flow, payload updateLoginFlowWithIdentifierFirstMethod, err error) error { if f != nil { f.UI.Nodes.SetValueAttribute("identifier", payload.Identifier) if f.Type == flow.TypeBrowser { @@ -52,12 +52,12 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, decoderx.HTTPDecoderSetValidatePayloads(true), decoderx.MustHTTPRawJSONSchemaCompiler(loginSchema), decoderx.HTTPDecoderJSONFollowsFormFormat()); err != nil { - return nil, s.handleLoginError(w, r, f, &p, err) + return nil, s.handleLoginError(r, f, p, err) } f.TransientPayload = p.TransientPayload if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { - return nil, s.handleLoginError(w, r, f, &p, err) + return nil, s.handleLoginError(r, f, p, err) } var opts []login.FormHydratorModifier @@ -74,11 +74,11 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, // This will later be handled by `didPopulate`. } else if err != nil { // An error happened during lookup - return nil, s.handleLoginError(w, r, f, &p, err) + return nil, s.handleLoginError(r, f, p, err) } else if !s.d.Config().SecurityAccountEnumerationMitigate(r.Context()) { // Hydrate credentials if err := s.d.PrivilegedIdentityPool().HydrateIdentityAssociations(r.Context(), identityHint, identity.ExpandCredentials); err != nil { - return nil, s.handleLoginError(w, r, f, &p, err) + return nil, s.handleLoginError(r, f, p, err) } } @@ -102,7 +102,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } else if errors.Is(err, ErrNoCredentialsFound) { // This strategy is not responsible for this flow. We do not set didPopulate to true if that happens. } else if err != nil { - return nil, s.handleLoginError(w, r, f, &p, err) + return nil, s.handleLoginError(r, f, p, err) } else { didPopulate = true } @@ -111,7 +111,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, // If no strategy populated, it means that the account (very likely) does not exist. We show a user not found error, // but only if account enumeration mitigation is disabled. Otherwise, we proceed to render the rest of the form. if !didPopulate && !s.d.Config().SecurityAccountEnumerationMitigate(r.Context()) { - return nil, s.handleLoginError(w, r, f, &p, errors.WithStack(schema.NewAccountNotFoundError())) + return nil, s.handleLoginError(r, f, p, errors.WithStack(schema.NewAccountNotFoundError())) } // We found credentials - hide the identifier. @@ -134,7 +134,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, f.Active = s.ID() if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil { - return nil, s.handleLoginError(w, r, f, &p, err) + return nil, s.handleLoginError(r, f, p, err) } if x.IsJSONRequest(r) { diff --git a/selfservice/strategy/link/strategy_verification.go b/selfservice/strategy/link/strategy_verification.go index 61f95da52fef..200334700c80 100644 --- a/selfservice/strategy/link/strategy_verification.go +++ b/selfservice/strategy/link/strategy_verification.go @@ -79,12 +79,16 @@ func (s *Strategy) decodeVerification(r *http.Request) (*verificationSubmitPaylo } // handleVerificationError is a convenience function for handling all types of errors that may occur (e.g. validation error). -func (s *Strategy) handleVerificationError(w http.ResponseWriter, r *http.Request, f *verification.Flow, body *verificationSubmitPayload, err error) error { +func (s *Strategy) handleVerificationError(r *http.Request, f *verification.Flow, body *verificationSubmitPayload, err error) error { if f != nil { f.UI.SetCSRF(s.d.GenerateCSRFToken(r)) + email := "" + if body != nil { + email = body.Email + } f.UI.GetNodes().Upsert( // v0.5: form.Field{Name: "email", Type: "email", Required: true, Value: body.Body.Email} - node.NewInputField("email", body.Email, node.LinkGroup, node.InputAttributeTypeEmail, node.WithRequiredInputAttribute).WithMetaLabel(text.NewInfoNodeInputEmail()), + node.NewInputField("email", email, node.LinkGroup, node.InputAttributeTypeEmail, node.WithRequiredInputAttribute).WithMetaLabel(text.NewInfoNodeInputEmail()), ) } @@ -132,32 +136,29 @@ func (s *Strategy) Verify(w http.ResponseWriter, r *http.Request, f *verificatio body, err := s.decodeVerification(r) if err != nil { - return s.handleVerificationError(w, r, nil, body, err) + return s.handleVerificationError(r, nil, body, err) } f.TransientPayload = body.TransientPayload if len(body.Token) > 0 { if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), s.VerificationStrategyID(), s.d); err != nil { - return s.handleVerificationError(w, r, nil, body, err) + return s.handleVerificationError(r, nil, body, err) } return s.verificationUseToken(w, r, body, f) } if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), body.Method, s.d); err != nil { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } if err := f.Valid(); err != nil { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } switch f.State { - case flow.StateChooseMethod: - fallthrough - case flow.StateEmailSent: - // Do nothing (continue with execution after this switch statement) - return s.verificationHandleFormSubmission(w, r, f) + case flow.StateChooseMethod, flow.StateEmailSent: + return s.verificationHandleFormSubmission(r, f) case flow.StatePassedChallenge: return s.retryVerificationFlowWithMessage(w, r, f.Type, text.NewErrorValidationVerificationRetrySuccess()) default: @@ -165,23 +166,23 @@ func (s *Strategy) Verify(w http.ResponseWriter, r *http.Request, f *verificatio } } -func (s *Strategy) verificationHandleFormSubmission(w http.ResponseWriter, r *http.Request, f *verification.Flow) error { +func (s *Strategy) verificationHandleFormSubmission(r *http.Request, f *verification.Flow) error { body, err := s.decodeVerification(r) if err != nil { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } if len(body.Email) == 0 { - return s.handleVerificationError(w, r, f, body, schema.NewRequiredError("#/email", "email")) + return s.handleVerificationError(r, f, body, schema.NewRequiredError("#/email", "email")) } if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, body.CSRFToken); err != nil { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } if err := s.d.LinkSender().SendVerificationLink(r.Context(), f, identity.VerifiableAddressTypeEmail, body.Email); err != nil { if !errors.Is(err, ErrUnknownAddress) { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } // Continue execution } @@ -196,7 +197,7 @@ func (s *Strategy) verificationHandleFormSubmission(w http.ResponseWriter, r *ht f.State = flow.StateEmailSent f.UI.Messages.Set(text.NewVerificationEmailSent()) if err := s.d.VerificationFlowPersister().UpdateVerificationFlow(r.Context(), f); err != nil { - return s.handleVerificationError(w, r, f, body, err) + return s.handleVerificationError(r, f, body, err) } return nil @@ -268,12 +269,12 @@ func (s *Strategy) retryVerificationFlowWithMessage(w http.ResponseWriter, r *ht f, err := verification.NewFlow(s.d.Config(), s.d.Config().SelfServiceFlowVerificationRequestLifespan(r.Context()), s.d.CSRFHandler().RegenerateToken(w, r), r, s, ft) if err != nil { - return s.handleVerificationError(w, r, f, nil, err) + return s.handleVerificationError(r, f, nil, err) } f.UI.Messages.Add(message) if err := s.d.VerificationFlowPersister().CreateVerificationFlow(r.Context(), f); err != nil { - return s.handleVerificationError(w, r, f, nil, err) + return s.handleVerificationError(r, f, nil, err) } if ft == flow.TypeBrowser { @@ -292,7 +293,7 @@ func (s *Strategy) retryVerificationFlowWithError(w http.ResponseWriter, r *http f, err := verification.NewFlow(s.d.Config(), s.d.Config().SelfServiceFlowVerificationRequestLifespan(r.Context()), s.d.CSRFHandler().RegenerateToken(w, r), r, s, ft) if err != nil { - return s.handleVerificationError(w, r, f, nil, err) + return s.handleVerificationError(r, f, nil, err) } if expired := new(flow.ExpiredError); errors.As(verErr, &expired) { @@ -304,7 +305,7 @@ func (s *Strategy) retryVerificationFlowWithError(w http.ResponseWriter, r *http } if err := s.d.VerificationFlowPersister().CreateVerificationFlow(r.Context(), f); err != nil { - return s.handleVerificationError(w, r, f, nil, err) + return s.handleVerificationError(r, f, nil, err) } if ft == flow.TypeBrowser { diff --git a/selfservice/strategy/lookup/settings.go b/selfservice/strategy/lookup/settings.go index 1136d4d83414..3eee82d8d28a 100644 --- a/selfservice/strategy/lookup/settings.go +++ b/selfservice/strategy/lookup/settings.go @@ -101,20 +101,20 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. var p updateSettingsFlowWithLookupMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { - return ctxUpdate, s.continueSettingsFlow(w, r, ctxUpdate, &p) + return ctxUpdate, s.continueSettingsFlow(r, ctxUpdate, p) } else if err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } if err := s.decodeSettingsFlow(r, &p); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } if p.RegenerateLookup || p.RevealLookup || p.ConfirmLookup || p.DisableLookup { // This method has only two submit buttons p.Method = s.SettingsStrategyID() if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { - return nil, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return nil, s.handleSettingsError(w, r, ctxUpdate, p, err) } } else { return nil, errors.WithStack(flow.ErrStrategyNotResponsible) @@ -122,8 +122,8 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. // This does not come from the payload! p.Flow = ctxUpdate.Flow.ID.String() - if err := s.continueSettingsFlow(w, r, ctxUpdate, &p); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + if err := s.continueSettingsFlow(r, ctxUpdate, p); err != nil { + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } return ctxUpdate, nil @@ -141,12 +141,10 @@ func (s *Strategy) decodeSettingsFlow(r *http.Request, dest interface{}) error { ) } -func (s *Strategy) continueSettingsFlow( - w http.ResponseWriter, r *http.Request, - ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithLookupMethod, -) error { +func (s *Strategy) continueSettingsFlow(r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithLookupMethod) error { + ctx := r.Context() if p.ConfirmLookup || p.RevealLookup || p.RegenerateLookup || p.DisableLookup { - if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, flow.SettingsFlow, s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { return err } @@ -162,16 +160,16 @@ func (s *Strategy) continueSettingsFlow( } if p.ConfirmLookup { - return s.continueSettingsFlowConfirm(w, r, ctxUpdate, p) + return s.continueSettingsFlowConfirm(ctx, ctxUpdate) } else if p.RevealLookup { - if err := s.continueSettingsFlowReveal(w, r, ctxUpdate, p); err != nil { + if err := s.continueSettingsFlowReveal(ctx, ctxUpdate); err != nil { return err } return flow.ErrStrategyAsksToReturnToUI } else if p.DisableLookup { - return s.continueSettingsFlowDisable(w, r, ctxUpdate, p) + return s.continueSettingsFlowDisable(ctx, ctxUpdate) } else if p.RegenerateLookup { - if err := s.continueSettingsFlowRegenerate(w, r, ctxUpdate, p); err != nil { + if err := s.continueSettingsFlowRegenerate(ctx, ctxUpdate); err != nil { return err } // regen @@ -181,8 +179,8 @@ func (s *Strategy) continueSettingsFlow( return errors.New("ended up in unexpected state") } -func (s *Strategy) continueSettingsFlowDisable(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithLookupMethod) error { - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.Identity.ID) +func (s *Strategy) continueSettingsFlowDisable(ctx context.Context, ctxUpdate *settings.UpdateContext) error { + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, ctxUpdate.Session.Identity.ID) if err != nil { return err } @@ -203,7 +201,7 @@ func (s *Strategy) continueSettingsFlowDisable(w http.ResponseWriter, r *http.Re return err } - if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(r.Context(), ctxUpdate.Flow); err != nil { + if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(ctx, ctxUpdate.Flow); err != nil { return err } @@ -211,8 +209,8 @@ func (s *Strategy) continueSettingsFlowDisable(w http.ResponseWriter, r *http.Re return nil } -func (s *Strategy) continueSettingsFlowReveal(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithLookupMethod) error { - hasLookup, err := s.identityHasLookup(r.Context(), ctxUpdate.Session.IdentityID) +func (s *Strategy) continueSettingsFlowReveal(ctx context.Context, ctxUpdate *settings.UpdateContext) error { + hasLookup, err := s.identityHasLookup(ctx, ctxUpdate.Session.IdentityID) if err != nil { return err } @@ -221,7 +219,7 @@ func (s *Strategy) continueSettingsFlowReveal(w http.ResponseWriter, r *http.Req return errors.WithStack(herodot.ErrBadRequest.WithReasonf("Can not reveal lookup codes because you have none.")) } - _, cred, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), s.ID(), ctxUpdate.Session.IdentityID.String()) + _, cred, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, s.ID(), ctxUpdate.Session.IdentityID.String()) if err != nil { return err } @@ -244,14 +242,14 @@ func (s *Strategy) continueSettingsFlowReveal(w http.ResponseWriter, r *http.Req return err } - if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(r.Context(), ctxUpdate.Flow); err != nil { + if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(ctx, ctxUpdate.Flow); err != nil { return err } return nil } -func (s *Strategy) continueSettingsFlowRegenerate(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithLookupMethod) error { +func (s *Strategy) continueSettingsFlowRegenerate(ctx context.Context, ctxUpdate *settings.UpdateContext) error { codes := make([]identity.RecoveryCode, numCodes) for k := range codes { codes[k] = identity.RecoveryCode{Code: randx.MustString(8, randx.AlphaLowerNum)} @@ -270,14 +268,14 @@ func (s *Strategy) continueSettingsFlowRegenerate(w http.ResponseWriter, r *http return err } - if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(r.Context(), ctxUpdate.Flow); err != nil { + if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(ctx, ctxUpdate.Flow); err != nil { return err } return nil } -func (s *Strategy) continueSettingsFlowConfirm(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithLookupMethod) error { +func (s *Strategy) continueSettingsFlowConfirm(ctx context.Context, ctxUpdate *settings.UpdateContext) error { codes := gjson.GetBytes(ctxUpdate.Flow.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeyRegenerated)).Array() if len(codes) != numCodes { return errors.WithStack(herodot.ErrBadRequest.WithReasonf("You must (re-)generate recovery backup codes before you can save them.")) @@ -293,7 +291,7 @@ func (s *Strategy) continueSettingsFlowConfirm(w http.ResponseWriter, r *http.Re return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to encode totp options to JSON: %s", err)) } - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.Identity.ID) + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, ctxUpdate.Session.Identity.ID) if err != nil { return err } @@ -309,12 +307,12 @@ func (s *Strategy) continueSettingsFlowConfirm(w http.ResponseWriter, r *http.Re return err } - if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(r.Context(), ctxUpdate.Flow); err != nil { + if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(ctx, ctxUpdate.Flow); err != nil { return err } // Since we added the method, it also means that we have authenticated it - if err := s.d.SessionManager().SessionAddAuthenticationMethods(r.Context(), ctxUpdate.Session.ID, session.AuthenticationMethod{ + if err := s.d.SessionManager().SessionAddAuthenticationMethods(ctx, ctxUpdate.Session.ID, session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel2, }); err != nil { @@ -357,7 +355,7 @@ func (s *Strategy) PopulateSettingsMethod(r *http.Request, id *identity.Identity return nil } -func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithLookupMethod, err error) error { +func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithLookupMethod, err error) error { // Do not pause flow if the flow type is an API flow as we can't save cookies in those flows. if e := new(settings.FlowNeedsReAuth); errors.As(err, &e) && ctxUpdate.Flow != nil && ctxUpdate.Flow.Type == flow.TypeBrowser { if err := s.d.ContinuityManager().Pause(r.Context(), w, r, settings.ContinuityKey(s.SettingsStrategyID()), settings.ContinuityOptions(p, ctxUpdate.GetSessionIdentity())...); err != nil { diff --git a/selfservice/strategy/lookup/settings_test.go b/selfservice/strategy/lookup/settings_test.go index 101f5919d04a..52b08786fd5a 100644 --- a/selfservice/strategy/lookup/settings_test.go +++ b/selfservice/strategy/lookup/settings_test.go @@ -398,7 +398,7 @@ func TestCompleteSettings(t *testing.T) { payloadConfirm(values) actual, res := testhelpers.SettingsMakeRequest(t, true, false, f, apiClient, testhelpers.EncodeFormAsJSON(t, true, values)) - assert.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, http.StatusOK, res.StatusCode) assert.Contains(t, res.Request.URL.String(), publicTS.URL+settings.RouteSubmitFlow) assert.EqualValues(t, flow.StateSuccess, json.RawMessage(gjson.Get(actual, "state").String())) diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 6b64194dcba6..d290276492a7 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -351,7 +351,7 @@ func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request if o.IdentityHint != nil { var err error // If we have an identity hint we check if the identity has any providers configured. - if linked, err = s.linkedProviders(r.Context(), r, conf, o.IdentityHint); err != nil { + if linked, err = s.linkedProviders(conf, o.IdentityHint); err != nil { return err } } diff --git a/selfservice/strategy/oidc/strategy_settings.go b/selfservice/strategy/oidc/strategy_settings.go index d4a92056a20b..a01f3003df83 100644 --- a/selfservice/strategy/oidc/strategy_settings.go +++ b/selfservice/strategy/oidc/strategy_settings.go @@ -86,7 +86,7 @@ func (s *Strategy) decoderSettings(p *updateSettingsFlowWithOidcMethod, r *http. return nil } -func (s *Strategy) linkedProviders(ctx context.Context, r *http.Request, conf *ConfigurationCollection, confidential *identity.Identity) ([]Provider, error) { +func (s *Strategy) linkedProviders(conf *ConfigurationCollection, confidential *identity.Identity) ([]Provider, error) { creds, ok := confidential.GetCredentials(s.ID()) if !ok { return nil, nil @@ -111,7 +111,7 @@ func (s *Strategy) linkedProviders(ctx context.Context, r *http.Request, conf *C return result, nil } -func (s *Strategy) linkableProviders(ctx context.Context, r *http.Request, conf *ConfigurationCollection, confidential *identity.Identity) ([]Provider, error) { +func (s *Strategy) linkableProviders(conf *ConfigurationCollection, confidential *identity.Identity) ([]Provider, error) { var available identity.CredentialsOIDC creds, ok := confidential.GetCredentials(s.ID()) if ok { @@ -143,26 +143,28 @@ func (s *Strategy) linkableProviders(ctx context.Context, r *http.Request, conf } func (s *Strategy) PopulateSettingsMethod(r *http.Request, id *identity.Identity, sr *settings.Flow) error { + ctx := r.Context() + if sr.Type != flow.TypeBrowser { return nil } - conf, err := s.Config(r.Context()) + conf, err := s.Config(ctx) if err != nil { return err } - confidential, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), id.ID) + confidential, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, id.ID) if err != nil { return err } - linkable, err := s.linkableProviders(r.Context(), r, conf, confidential) + linkable, err := s.linkableProviders(conf, confidential) if err != nil { return err } - linked, err := s.linkedProviders(r.Context(), r, conf, confidential) + linked, err := s.linkedProviders(conf, confidential) if err != nil { return err } @@ -177,7 +179,7 @@ func (s *Strategy) PopulateSettingsMethod(r *http.Request, id *identity.Identity sr.UI.GetNodes().Append(NewLinkNode(l.Config().ID, stringsx.Coalesce(l.Config().Label, l.Config().ID))) } - count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(r.Context(), confidential) + count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(ctx, confidential) if err != nil { return err } @@ -322,18 +324,18 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. }))) } -func (s *Strategy) isLinkable(r *http.Request, ctxUpdate *settings.UpdateContext, toLink string) (*identity.Identity, error) { - providers, err := s.Config(r.Context()) +func (s *Strategy) isLinkable(ctx context.Context, ctxUpdate *settings.UpdateContext, toLink string) (*identity.Identity, error) { + providers, err := s.Config(ctx) if err != nil { return nil, err } - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.Identity.ID) + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, ctxUpdate.Session.Identity.ID) if err != nil { return nil, err } - linkable, err := s.linkableProviders(r.Context(), r, providers, i) + linkable, err := s.linkableProviders(providers, i) if err != nil { return nil, err } @@ -353,26 +355,27 @@ func (s *Strategy) isLinkable(r *http.Request, ctxUpdate *settings.UpdateContext } func (s *Strategy) initLinkProvider(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithOidcMethod) error { - if _, err := s.isLinkable(r, ctxUpdate, p.Link); err != nil { + ctx := r.Context() + if _, err := s.isLinkable(ctx, ctxUpdate, p.Link); err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } - if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(r.Context())).Before(time.Now()) { + if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(ctx)).Before(time.Now()) { return s.handleSettingsError(w, r, ctxUpdate, p, errors.WithStack(settings.NewFlowNeedsReAuth())) } - provider, err := s.provider(r.Context(), r, p.Link) + provider, err := s.provider(ctx, r, p.Link) if err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } - req, err := s.validateFlow(r.Context(), r, ctxUpdate.Flow.ID) + req, err := s.validateFlow(ctx, r, ctxUpdate.Flow.ID) if err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } state := generateState(ctxUpdate.Flow.ID.String()) - if err := s.d.ContinuityManager().Pause(r.Context(), w, r, sessionName, + if err := s.d.ContinuityManager().Pause(ctx, w, r, sessionName, continuity.WithPayload(&AuthCodeContainer{ State: state.String(), FlowID: ctxUpdate.Flow.ID.String(), @@ -387,7 +390,7 @@ func (s *Strategy) initLinkProvider(w http.ResponseWriter, r *http.Request, ctxU return err } - codeURL, err := getAuthRedirectURL(r.Context(), provider, req, state, up) + codeURL, err := getAuthRedirectURL(ctx, provider, req, state, up) if err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } @@ -402,19 +405,20 @@ func (s *Strategy) initLinkProvider(w http.ResponseWriter, r *http.Request, ctxU } func (s *Strategy) linkProvider(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider) error { + ctx := r.Context() p := &updateSettingsFlowWithOidcMethod{ Link: provider.Config().ID, FlowID: ctxUpdate.Flow.ID.String(), } - if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(r.Context())).Before(time.Now()) { + if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(ctx)).Before(time.Now()) { return s.handleSettingsError(w, r, ctxUpdate, p, errors.WithStack(settings.NewFlowNeedsReAuth())) } - i, err := s.isLinkable(r, ctxUpdate, p.Link) + i, err := s.isLinkable(ctx, ctxUpdate, p.Link) if err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } - if err := s.linkCredentials(r.Context(), i, token, provider.Config().ID, claims.Subject, provider.Config().OrganizationID); err != nil { + if err := s.linkCredentials(ctx, i, token, provider.Config().ID, claims.Subject, provider.Config().OrganizationID); err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } @@ -428,21 +432,22 @@ func (s *Strategy) linkProvider(w http.ResponseWriter, r *http.Request, ctxUpdat } func (s *Strategy) unlinkProvider(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithOidcMethod) error { - if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(r.Context())).Before(time.Now()) { + ctx := r.Context() + if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(ctx)).Before(time.Now()) { return s.handleSettingsError(w, r, ctxUpdate, p, errors.WithStack(settings.NewFlowNeedsReAuth())) } - providers, err := s.Config(r.Context()) + providers, err := s.Config(ctx) if err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.Identity.ID) + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, ctxUpdate.Session.Identity.ID) if err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } - availableProviders, err := s.linkedProviders(r.Context(), r, providers, i) + availableProviders, err := s.linkedProviders(providers, i) if err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } @@ -453,7 +458,7 @@ func (s *Strategy) unlinkProvider(w http.ResponseWriter, r *http.Request, ctxUpd return s.handleSettingsError(w, r, ctxUpdate, p, err) } - count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(r.Context(), i) + count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(ctx, i) if err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } diff --git a/selfservice/strategy/passkey/passkey_settings.go b/selfservice/strategy/passkey/passkey_settings.go index 89423bf8adeb..7214e32660d3 100644 --- a/selfservice/strategy/passkey/passkey_settings.go +++ b/selfservice/strategy/passkey/passkey_settings.go @@ -167,20 +167,20 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. var p updateSettingsFlowWithPasskeyMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { - return ctxUpdate, s.continueSettingsFlow(w, r, ctxUpdate, &p) + return ctxUpdate, s.continueSettingsFlow(w, r, ctxUpdate, p) } else if err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } if err := s.decodeSettingsFlow(r, &p); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } if len(p.Register+p.Remove) > 0 { // This method has only two submit buttons p.Method = s.SettingsStrategyID() if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { - return nil, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return nil, s.handleSettingsError(w, r, ctxUpdate, p, err) } } else { return nil, errors.WithStack(flow.ErrStrategyNotResponsible) @@ -188,8 +188,8 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. // This does not come from the payload! p.Flow = ctxUpdate.Flow.ID.String() - if err := s.continueSettingsFlow(w, r, ctxUpdate, &p); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + if err := s.continueSettingsFlow(w, r, ctxUpdate, p); err != nil { + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } return ctxUpdate, nil @@ -236,7 +236,7 @@ func (p *updateSettingsFlowWithPasskeyMethod) SetFlowID(rid uuid.UUID) { func (s *Strategy) continueSettingsFlow( w http.ResponseWriter, r *http.Request, - ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithPasskeyMethod, + ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasskeyMethod, ) error { if len(p.Register+p.Remove) > 0 { if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { @@ -264,7 +264,7 @@ func (s *Strategy) continueSettingsFlow( } } -func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithPasskeyMethod) error { +func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasskeyMethod) error { i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.IdentityID) if err != nil { return err @@ -317,7 +317,7 @@ func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Req return nil } -func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithPasskeyMethod) error { +func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasskeyMethod) error { webAuthnSession := gjson.GetBytes(ctxUpdate.Flow.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeySessionData)) if !webAuthnSession.IsObject() { return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected WebAuthN in internal context to be an object.")) @@ -408,7 +408,7 @@ func (s *Strategy) decodeSettingsFlow(r *http.Request, dest interface{}) error { ) } -func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithPasskeyMethod, err error) error { +func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasskeyMethod, err error) error { // Do not pause flow if the flow type is an API flow as we can't save cookies in those flows. if e := new(settings.FlowNeedsReAuth); errors.As(err, &e) && ctxUpdate.Flow != nil && ctxUpdate.Flow.Type == flow.TypeBrowser { if err := s.d.ContinuityManager().Pause(r.Context(), w, r, settings.ContinuityKey(s.SettingsStrategyID()), settings.ContinuityOptions(p, ctxUpdate.GetSessionIdentity())...); err != nil { diff --git a/selfservice/strategy/passkey/passkey_settings_test.go b/selfservice/strategy/passkey/passkey_settings_test.go index a37fe39a38f1..606caa838774 100644 --- a/selfservice/strategy/passkey/passkey_settings_test.go +++ b/selfservice/strategy/passkey/passkey_settings_test.go @@ -249,6 +249,7 @@ func TestCompleteSettings(t *testing.T) { values.Set("method", "passkey") values.Set(node.PasskeySettingsRegister, string(settingsFixtureSuccessResponse)) body, res := testhelpers.SettingsMakeRequest(t, false, spa, f, browserClient, testhelpers.EncodeFormAsJSON(t, spa, values)) + require.Equal(t, http.StatusOK, res.StatusCode, "%s", body) if spa { assert.Contains(t, res.Request.URL.String(), fix.publicTS.URL+settings.RouteSubmitFlow) @@ -260,7 +261,7 @@ func TestCompleteSettings(t *testing.T) { actual, err := fix.reg.Persister().GetIdentityConfidential(fix.ctx, id.ID) require.NoError(t, err) cred, ok := actual.GetCredentials(identity.CredentialsTypePasskey) - assert.True(t, ok) + require.True(t, ok) assert.Len(t, gjson.GetBytes(cred.Config, "credentials").Array(), 1) actualFlow, err := fix.reg.SettingsFlowPersister().GetSettingsFlow(fix.ctx, uuid.FromStringOrNil(f.Id)) diff --git a/selfservice/strategy/password/login.go b/selfservice/strategy/password/login.go index 89730d299463..91e59085c8ef 100644 --- a/selfservice/strategy/password/login.go +++ b/selfservice/strategy/password/login.go @@ -35,7 +35,7 @@ var _ login.FormHydrator = new(Strategy) func (s *Strategy) RegisterLoginRoutes(r *x.RouterPublic) { } -func (s *Strategy) handleLoginError(r *http.Request, f *login.Flow, payload *updateLoginFlowWithPasswordMethod, err error) error { +func (s *Strategy) handleLoginError(r *http.Request, f *login.Flow, payload updateLoginFlowWithPasswordMethod, err error) error { if f != nil { f.UI.Nodes.ResetNodes("password") f.UI.Nodes.SetValueAttribute("identifier", stringsx.Coalesce(payload.Identifier, payload.LegacyIdentifier)) @@ -61,19 +61,19 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, decoderx.HTTPDecoderSetValidatePayloads(true), decoderx.MustHTTPRawJSONSchemaCompiler(loginSchema), decoderx.HTTPDecoderJSONFollowsFormFormat()); err != nil { - return nil, s.handleLoginError(r, f, &p, err) + return nil, s.handleLoginError(r, f, p, err) } f.TransientPayload = p.TransientPayload if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { - return nil, s.handleLoginError(r, f, &p, err) + return nil, s.handleLoginError(r, f, p, err) } identifier := stringsx.Coalesce(p.Identifier, p.LegacyIdentifier) i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), s.ID(), identifier) if err != nil { time.Sleep(x.RandomDelay(s.d.Config().HasherArgon2(r.Context()).ExpectedDuration, s.d.Config().HasherArgon2(r.Context()).ExpectedDeviation)) - return nil, s.handleLoginError(r, f, &p, errors.WithStack(schema.NewInvalidCredentialsError())) + return nil, s.handleLoginError(r, f, p, errors.WithStack(schema.NewInvalidCredentialsError())) } var o identity.CredentialsPassword @@ -91,27 +91,27 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, migrationHook := hook.NewPasswordMigrationHook(s.d, pwHook.Config) err = migrationHook.Execute(r.Context(), &hook.PasswordMigrationRequest{Identifier: identifier, Password: p.Password}) if err != nil { - return nil, s.handleLoginError(r, f, &p, err) + return nil, s.handleLoginError(r, f, p, err) } if err := s.migratePasswordHash(r.Context(), i.ID, []byte(p.Password)); err != nil { - return nil, s.handleLoginError(r, f, &p, err) + return nil, s.handleLoginError(r, f, p, err) } } else { if err := hash.Compare(r.Context(), []byte(p.Password), []byte(o.HashedPassword)); err != nil { - return nil, s.handleLoginError(r, f, &p, errors.WithStack(schema.NewInvalidCredentialsError())) + return nil, s.handleLoginError(r, f, p, errors.WithStack(schema.NewInvalidCredentialsError())) } if !s.d.Hasher(r.Context()).Understands([]byte(o.HashedPassword)) { if err := s.migratePasswordHash(r.Context(), i.ID, []byte(p.Password)); err != nil { - return nil, s.handleLoginError(r, f, &p, err) + return nil, s.handleLoginError(r, f, p, err) } } } f.Active = s.ID() if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil { - return nil, s.handleLoginError(r, f, &p, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error()))) + return nil, s.handleLoginError(r, f, p, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error()))) } return i, nil diff --git a/selfservice/strategy/password/registration.go b/selfservice/strategy/password/registration.go index 55970fdf0b5e..ba11733fd0ec 100644 --- a/selfservice/strategy/password/registration.go +++ b/selfservice/strategy/password/registration.go @@ -56,13 +56,11 @@ type UpdateRegistrationFlowWithPasswordMethod struct { func (s *Strategy) RegisterRegistrationRoutes(*x.RouterPublic) { } -func (s *Strategy) handleRegistrationError(_ http.ResponseWriter, r *http.Request, f *registration.Flow, p *UpdateRegistrationFlowWithPasswordMethod, err error) error { +func (s *Strategy) handleRegistrationError(r *http.Request, f *registration.Flow, p UpdateRegistrationFlowWithPasswordMethod, err error) error { if f != nil { - if p != nil { - for _, n := range container.NewFromJSON("", node.ProfileGroup, p.Traits, "traits").Nodes { - // we only set the value and not the whole field because we want to keep types from the initial form generation - f.UI.Nodes.SetValueAttribute(n.ID(), n.Attributes.GetValue()) - } + for _, n := range container.NewFromJSON("", node.ProfileGroup, p.Traits, "traits").Nodes { + // we only set the value and not the whole field because we want to keep types from the initial form generation + f.UI.Nodes.SetValueAttribute(n.ID(), n.Attributes.GetValue()) } if f.Type == flow.TypeBrowser { @@ -84,17 +82,17 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat var p UpdateRegistrationFlowWithPasswordMethod if err := s.decode(&p, r); err != nil { - return s.handleRegistrationError(w, r, f, &p, err) + return s.handleRegistrationError(r, f, p, err) } f.TransientPayload = p.TransientPayload if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { - return s.handleRegistrationError(w, r, f, &p, err) + return s.handleRegistrationError(r, f, p, err) } if len(p.Password) == 0 { - return s.handleRegistrationError(w, r, f, &p, schema.NewRequiredError("#/password", "password")) + return s.handleRegistrationError(r, f, p, schema.NewRequiredError("#/password", "password")) } if len(p.Traits) == 0 { @@ -116,27 +114,27 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat }() if err != nil { - return s.handleRegistrationError(w, r, f, &p, err) + return s.handleRegistrationError(r, f, p, err) } i.Traits = identity.Traits(p.Traits) // We have to set the credential here, so the identity validator can populate the identifiers. // The password hash is computed in parallel and set later. if err := i.SetCredentialsWithConfig(s.ID(), identity.Credentials{Type: s.ID(), Identifiers: []string{}}, json.RawMessage("{}")); err != nil { - return s.handleRegistrationError(w, r, f, &p, err) + return s.handleRegistrationError(r, f, p, err) } if err := s.validateCredentials(r.Context(), i, p.Password); err != nil { - return s.handleRegistrationError(w, r, f, &p, err) + return s.handleRegistrationError(r, f, p, err) } select { case err := <-errC: - return s.handleRegistrationError(w, r, f, &p, err) + return s.handleRegistrationError(r, f, p, err) case h := <-hpw: co, err := json.Marshal(&identity.CredentialsPassword{HashedPassword: string(h)}) if err != nil { - return s.handleRegistrationError(w, r, f, &p, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to encode password options to JSON: %s", err))) + return s.handleRegistrationError(r, f, p, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to encode password options to JSON: %s", err))) } i.UpsertCredentialsConfig(s.ID(), co, 0) } diff --git a/selfservice/strategy/password/settings.go b/selfservice/strategy/password/settings.go index f763163d3180..0995a73f5076 100644 --- a/selfservice/strategy/password/settings.go +++ b/selfservice/strategy/password/settings.go @@ -75,23 +75,23 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. var p updateSettingsFlowWithPasswordMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { - return ctxUpdate, s.continueSettingsFlow(w, r, ctxUpdate, &p) + return ctxUpdate, s.continueSettingsFlow(r, ctxUpdate, p) } else if err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.SettingsStrategyID(), s.d); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } if err := s.decodeSettingsFlow(r, &p); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } // This does not come from the payload! p.Flow = ctxUpdate.Flow.ID.String() - if err := s.continueSettingsFlow(w, r, ctxUpdate, &p); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + if err := s.continueSettingsFlow(r, ctxUpdate, p); err != nil { + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } return ctxUpdate, nil @@ -110,10 +110,7 @@ func (s *Strategy) decodeSettingsFlow(r *http.Request, dest interface{}) error { ) } -func (s *Strategy) continueSettingsFlow( - w http.ResponseWriter, r *http.Request, - ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithPasswordMethod, -) error { +func (s *Strategy) continueSettingsFlow(r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasswordMethod) error { if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), p.Method, s.d); err != nil { return err } @@ -175,7 +172,7 @@ func (s *Strategy) PopulateSettingsMethod(r *http.Request, _ *identity.Identity, return nil } -func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithPasswordMethod, err error) error { +func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasswordMethod, err error) error { // Do not pause flow if the flow type is an API flow as we can't save cookies in those flows. if e := new(settings.FlowNeedsReAuth); errors.As(err, &e) && ctxUpdate.Flow != nil && ctxUpdate.Flow.Type == flow.TypeBrowser { if err := s.d.ContinuityManager().Pause(r.Context(), w, r, settings.ContinuityKey(s.SettingsStrategyID()), settings.ContinuityOptions(p, ctxUpdate.GetSessionIdentity())...); err != nil { diff --git a/selfservice/strategy/profile/strategy.go b/selfservice/strategy/profile/strategy.go index b83e6ad527b2..644f8dd6c263 100644 --- a/selfservice/strategy/profile/strategy.go +++ b/selfservice/strategy/profile/strategy.go @@ -113,9 +113,9 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. var p updateSettingsFlowWithProfileMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { - return ctxUpdate, s.continueFlow(w, r, ctxUpdate, &p) + return ctxUpdate, s.continueFlow(r, ctxUpdate, p) } else if err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, p, err) } if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.SettingsStrategyID(), s.d); err != nil { @@ -124,7 +124,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. option, err := s.newSettingsProfileDecoder(r.Context(), ctxUpdate.GetSessionIdentity()) if err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, p, err) } if err := s.dc.Decode(r, &p, option, @@ -132,20 +132,20 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. decoderx.HTTPDecoderSetValidatePayloads(true), decoderx.HTTPDecoderJSONFollowsFormFormat(), ); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, p, err) } // Reset after decoding form p.SetFlowID(ctxUpdate.Flow.ID) - if err := s.continueFlow(w, r, ctxUpdate, &p); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, &p, err) + if err := s.continueFlow(r, ctxUpdate, p); err != nil { + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, p, err) } return ctxUpdate, nil } -func (s *Strategy) continueFlow(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithProfileMethod) error { +func (s *Strategy) continueFlow(r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithProfileMethod) error { if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), p.Method, s.d); err != nil { return err } @@ -237,7 +237,7 @@ func (s *Strategy) hydrateForm(r *http.Request, ar *settings.Flow, ss *session.S // handleSettingsError is a convenience function for handling all types of errors that may occur (e.g. validation error) // during a settings request. -func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, puc *settings.UpdateContext, traits json.RawMessage, p *updateSettingsFlowWithProfileMethod, err error) error { +func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, puc *settings.UpdateContext, traits json.RawMessage, p updateSettingsFlowWithProfileMethod, err error) error { if e := new(settings.FlowNeedsReAuth); errors.As(err, &e) { if err := s.d.ContinuityManager().Pause(r.Context(), w, r, settings.ContinuityKey(s.SettingsStrategyID()), diff --git a/selfservice/strategy/profile/two_step_registration.go b/selfservice/strategy/profile/two_step_registration.go index 98f4fe806913..bdaa4964ffc7 100644 --- a/selfservice/strategy/profile/two_step_registration.go +++ b/selfservice/strategy/profile/two_step_registration.go @@ -117,7 +117,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, regFlow *reg var params updateRegistrationFlowWithProfileMethod if err = s.decode(¶ms, r); err != nil { - return s.handleRegistrationError(w, r, regFlow, ¶ms, err) + return s.handleRegistrationError(r, regFlow, params, err) } if params.Screen == "credential-selection" { @@ -139,12 +139,12 @@ func (s *Strategy) displayStepOneNodes(w http.ResponseWriter, r *http.Request, r regFlow.UI.ResetMessages() err := json.Unmarshal([]byte(gjson.GetBytes(regFlow.InternalContext, "stepOneNodes").Raw), ®Flow.UI.Nodes) if err != nil { - return s.handleRegistrationError(w, r, regFlow, ¶ms, err) + return s.handleRegistrationError(r, regFlow, params, err) } regFlow.UI.UpdateNodeValuesFromJSON(params.Traits, "traits", node.DefaultGroup) if err := s.d.RegistrationFlowPersister().UpdateRegistrationFlow(ctx, regFlow); err != nil { - return s.handleRegistrationError(w, r, regFlow, ¶ms, err) + return s.handleRegistrationError(r, regFlow, params, err) } redirectTo := regFlow.AppendTo(s.d.Config().SelfServiceFlowRegistrationUI(ctx)).String() @@ -168,7 +168,7 @@ func (s *Strategy) displayStepTwoNodes(w http.ResponseWriter, r *http.Request, r regFlow.TransientPayload = params.TransientPayload if err := flow.EnsureCSRF(s.d, r, regFlow.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, params.CSRFToken); err != nil { - return s.handleRegistrationError(w, r, regFlow, ¶ms, err) + return s.handleRegistrationError(r, regFlow, params, err) } if len(params.Traits) == 0 { @@ -176,12 +176,12 @@ func (s *Strategy) displayStepTwoNodes(w http.ResponseWriter, r *http.Request, r } i.Traits = identity.Traits(params.Traits) if err := s.d.IdentityValidator().Validate(ctx, i); err != nil { - return s.handleRegistrationError(w, r, regFlow, ¶ms, err) + return s.handleRegistrationError(r, regFlow, params, err) } err := json.Unmarshal([]byte(gjson.GetBytes(regFlow.InternalContext, "stepTwoNodes").Raw), ®Flow.UI.Nodes) if err != nil { - return s.handleRegistrationError(w, r, regFlow, ¶ms, err) + return s.handleRegistrationError(r, regFlow, params, err) } regFlow.UI.Messages.Add(text.NewInfoSelfServiceChooseCredentials()) @@ -208,7 +208,7 @@ func (s *Strategy) displayStepTwoNodes(w http.ResponseWriter, r *http.Request, r } if err = s.d.RegistrationFlowPersister().UpdateRegistrationFlow(ctx, regFlow); err != nil { - return s.handleRegistrationError(w, r, regFlow, ¶ms, err) + return s.handleRegistrationError(r, regFlow, params, err) } redirectTo := regFlow.AppendTo(s.d.Config().SelfServiceFlowRegistrationUI(ctx)).String() @@ -221,13 +221,11 @@ func (s *Strategy) displayStepTwoNodes(w http.ResponseWriter, r *http.Request, r return flow.ErrCompletedByStrategy } -func (s *Strategy) handleRegistrationError(_ http.ResponseWriter, r *http.Request, regFlow *registration.Flow, params *updateRegistrationFlowWithProfileMethod, err error) error { +func (s *Strategy) handleRegistrationError(r *http.Request, regFlow *registration.Flow, params updateRegistrationFlowWithProfileMethod, err error) error { if regFlow != nil { - if params != nil { - for _, n := range container.NewFromJSON("", node.ProfileGroup, params.Traits, "traits").Nodes { - // we only set the value and not the whole field because we want to keep types from the initial form generation - regFlow.UI.Nodes.SetValueAttribute(n.ID(), n.Attributes.GetValue()) - } + for _, n := range container.NewFromJSON("", node.ProfileGroup, params.Traits, "traits").Nodes { + // we only set the value and not the whole field because we want to keep types from the initial form generation + regFlow.UI.Nodes.SetValueAttribute(n.ID(), n.Attributes.GetValue()) } if regFlow.Type == flow.TypeBrowser { diff --git a/selfservice/strategy/totp/settings.go b/selfservice/strategy/totp/settings.go index bbb3f5496dc5..0c24d915868b 100644 --- a/selfservice/strategy/totp/settings.go +++ b/selfservice/strategy/totp/settings.go @@ -87,29 +87,29 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. var p updateSettingsFlowWithTotpMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { - return ctxUpdate, s.continueSettingsFlow(w, r, ctxUpdate, &p) + return ctxUpdate, s.continueSettingsFlow(r, ctxUpdate, p) } else if err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } if err := s.decodeSettingsFlow(r, &p); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } if p.UnlinkTOTP { // This is a submit so we need to manually set the type to TOTP p.Method = s.SettingsStrategyID() if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { - return nil, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return nil, s.handleSettingsError(w, r, ctxUpdate, p, err) } } else if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.SettingsStrategyID(), s.d); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } // This does not come from the payload! p.Flow = ctxUpdate.Flow.ID.String() - if err := s.continueSettingsFlow(w, r, ctxUpdate, &p); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + if err := s.continueSettingsFlow(r, ctxUpdate, p); err != nil { + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } return ctxUpdate, nil @@ -128,23 +128,21 @@ func (s *Strategy) decodeSettingsFlow(r *http.Request, dest interface{}) error { ) } -func (s *Strategy) continueSettingsFlow( - w http.ResponseWriter, r *http.Request, - ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithTotpMethod, -) error { - if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), p.Method, s.d); err != nil { +func (s *Strategy) continueSettingsFlow(r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithTotpMethod) error { + ctx := r.Context() + if err := flow.MethodEnabledAndAllowed(ctx, flow.SettingsFlow, s.SettingsStrategyID(), p.Method, s.d); err != nil { return err } - if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return err } - if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(r.Context())).Before(time.Now()) { + if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(ctx)).Before(time.Now()) { return errors.WithStack(settings.NewFlowNeedsReAuth()) } - hasTOTP, err := s.identityHasTOTP(r.Context(), ctxUpdate.Session.IdentityID) + hasTOTP, err := s.identityHasTOTP(ctx, ctxUpdate.Session.IdentityID) if err != nil { return err } @@ -155,9 +153,9 @@ func (s *Strategy) continueSettingsFlow( // 2. TOTP should be added -> we do not have it yet var i *identity.Identity if hasTOTP { - i, err = s.continueSettingsFlowRemoveTOTP(w, r, ctxUpdate, p) + i, err = s.continueSettingsFlowRemoveTOTP(ctx, ctxUpdate, p) } else { - i, err = s.continueSettingsFlowAddTOTP(w, r, ctxUpdate, p) + i, err = s.continueSettingsFlowAddTOTP(ctx, ctxUpdate, p) } if err != nil { @@ -168,7 +166,7 @@ func (s *Strategy) continueSettingsFlow( return nil } -func (s *Strategy) continueSettingsFlowAddTOTP(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithTotpMethod) (*identity.Identity, error) { +func (s *Strategy) continueSettingsFlowAddTOTP(ctx context.Context, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithTotpMethod) (*identity.Identity, error) { keyURL := gjson.GetBytes(ctxUpdate.Flow.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeyURL)).String() if len(keyURL) == 0 { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Could not find they TOTP key in the internal context. This is a code bug and should be reported to https://github.com/ory/kratos/.")) @@ -196,7 +194,7 @@ func (s *Strategy) continueSettingsFlowAddTOTP(w http.ResponseWriter, r *http.Re return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to encode totp options to JSON: %s", err)) } - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.Identity.ID) + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, ctxUpdate.Session.Identity.ID) if err != nil { return nil, err } @@ -212,12 +210,12 @@ func (s *Strategy) continueSettingsFlowAddTOTP(w http.ResponseWriter, r *http.Re return nil, err } - if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(r.Context(), ctxUpdate.Flow); err != nil { + if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(ctx, ctxUpdate.Flow); err != nil { return nil, err } // Since we added the method, it also means that we have authenticated it - if err := s.d.SessionManager().SessionAddAuthenticationMethods(r.Context(), ctxUpdate.Session.ID, session.AuthenticationMethod{ + if err := s.d.SessionManager().SessionAddAuthenticationMethods(ctx, ctxUpdate.Session.ID, session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel2, }); err != nil { @@ -227,12 +225,12 @@ func (s *Strategy) continueSettingsFlowAddTOTP(w http.ResponseWriter, r *http.Re return i, nil } -func (s *Strategy) continueSettingsFlowRemoveTOTP(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithTotpMethod) (*identity.Identity, error) { +func (s *Strategy) continueSettingsFlowRemoveTOTP(ctx context.Context, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithTotpMethod) (*identity.Identity, error) { if !p.UnlinkTOTP { return ctxUpdate.Session.Identity, nil } - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.Identity.ID) + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, ctxUpdate.Session.Identity.ID) if err != nil { return nil, err } @@ -295,7 +293,7 @@ func (s *Strategy) PopulateSettingsMethod(r *http.Request, id *identity.Identity return nil } -func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithTotpMethod, err error) error { +func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithTotpMethod, err error) error { // Do not pause flow if the flow type is an API flow as we can't save cookies in those flows. if e := new(settings.FlowNeedsReAuth); errors.As(err, &e) && ctxUpdate.Flow != nil && ctxUpdate.Flow.Type == flow.TypeBrowser { if err := s.d.ContinuityManager().Pause(r.Context(), w, r, settings.ContinuityKey(s.SettingsStrategyID()), settings.ContinuityOptions(p, ctxUpdate.GetSessionIdentity())...); err != nil { diff --git a/selfservice/strategy/webauthn/registration.go b/selfservice/strategy/webauthn/registration.go index e3ca6c9e5fd7..81d94e0028e7 100644 --- a/selfservice/strategy/webauthn/registration.go +++ b/selfservice/strategy/webauthn/registration.go @@ -69,17 +69,15 @@ type updateRegistrationFlowWithWebAuthnMethod struct { func (s *Strategy) RegisterRegistrationRoutes(_ *x.RouterPublic) { } -func (s *Strategy) handleRegistrationError(_ http.ResponseWriter, r *http.Request, f *registration.Flow, p *updateRegistrationFlowWithWebAuthnMethod, err error) error { +func (s *Strategy) handleRegistrationError(r *http.Request, f *registration.Flow, p updateRegistrationFlowWithWebAuthnMethod, err error) error { if f != nil { - if p != nil { - for _, n := range container.NewFromJSON("", node.DefaultGroup, p.Traits, "traits").Nodes { - // we only set the value and not the whole field because we want to keep types from the initial form generation - f.UI.Nodes.SetValueAttribute(n.ID(), n.Attributes.GetValue()) - } - - f.UI.Nodes.SetValueAttribute(node.WebAuthnRegisterDisplayName, p.RegisterDisplayName) + for _, n := range container.NewFromJSON("", node.DefaultGroup, p.Traits, "traits").Nodes { + // we only set the value and not the whole field because we want to keep types from the initial form generation + f.UI.Nodes.SetValueAttribute(n.ID(), n.Attributes.GetValue()) } + f.UI.Nodes.SetValueAttribute(node.WebAuthnRegisterDisplayName, p.RegisterDisplayName) + if f.Type == flow.TypeBrowser { f.UI.SetCSRF(s.d.GenerateCSRFToken(r)) } @@ -101,13 +99,13 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, regFlow *reg var p updateRegistrationFlowWithWebAuthnMethod if err := s.decode(&p, r); err != nil { - return s.handleRegistrationError(w, r, regFlow, &p, err) + return s.handleRegistrationError(r, regFlow, p, err) } regFlow.TransientPayload = p.TransientPayload if err := flow.EnsureCSRF(s.d, r, regFlow.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { - return s.handleRegistrationError(w, r, regFlow, &p, err) + return s.handleRegistrationError(r, regFlow, p, err) } if len(p.Register) == 0 { @@ -116,7 +114,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, regFlow *reg p.Method = s.SettingsStrategyID() if err := flow.MethodEnabledAndAllowed(ctx, regFlow.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { - return s.handleRegistrationError(w, r, regFlow, &p, err) + return s.handleRegistrationError(r, regFlow, p, err) } if len(p.Traits) == 0 { @@ -126,25 +124,25 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, regFlow *reg webAuthnSession := gjson.GetBytes(regFlow.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeySessionData)) if !webAuthnSession.IsObject() { - return s.handleRegistrationError(w, r, regFlow, &p, errors.WithStack( + return s.handleRegistrationError(r, regFlow, p, errors.WithStack( herodot.ErrInternalServerError.WithReasonf("Expected WebAuthN in internal context to be an object."))) } var webAuthnSess webauthn.SessionData if err := json.Unmarshal([]byte(webAuthnSession.Raw), &webAuthnSess); err != nil { - return s.handleRegistrationError(w, r, regFlow, &p, errors.WithStack( + return s.handleRegistrationError(r, regFlow, p, errors.WithStack( herodot.ErrInternalServerError.WithReasonf("Expected WebAuthN in internal context to be an object but got: %s", err))) } webAuthnResponse, err := protocol.ParseCredentialCreationResponseBody(strings.NewReader(p.Register)) if err != nil { - return s.handleRegistrationError(w, r, regFlow, &p, errors.WithStack( + return s.handleRegistrationError(r, regFlow, p, errors.WithStack( herodot.ErrBadRequest.WithReasonf("Unable to parse WebAuthn response: %s", err))) } web, err := webauthn.New(s.d.Config().WebAuthnConfig(r.Context())) if err != nil { - return s.handleRegistrationError(w, r, regFlow, &p, errors.WithStack( + return s.handleRegistrationError(r, regFlow, p, errors.WithStack( herodot.ErrInternalServerError.WithReasonf("Unable to get webAuthn config.").WithDebug(err.Error()))) } @@ -153,7 +151,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, regFlow *reg if devErr := new(protocol.Error); errors.As(err, &devErr) { s.d.Logger().WithError(err).WithField("error_devinfo", devErr.DevInfo).Error("Failed to create WebAuthn credential") } - return s.handleRegistrationError(w, r, regFlow, &p, errors.WithStack( + return s.handleRegistrationError(r, regFlow, p, errors.WithStack( herodot.ErrInternalServerError.WithReasonf("Unable to create WebAuthn credential: %s", err))) } @@ -164,23 +162,23 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, regFlow *reg UserHandle: webAuthnSess.UserID, }) if err != nil { - return s.handleRegistrationError(w, r, regFlow, &p, errors.WithStack( + return s.handleRegistrationError(r, regFlow, p, errors.WithStack( herodot.ErrInternalServerError.WithReasonf("Unable to encode identity credentials.").WithDebug(err.Error()))) } i.UpsertCredentialsConfig(s.ID(), credentialWebAuthnConfig, 1) if err := s.validateCredentials(ctx, i); err != nil { - return s.handleRegistrationError(w, r, regFlow, &p, err) + return s.handleRegistrationError(r, regFlow, p, err) } // Remove the WebAuthn URL from the internal context now that it is set! regFlow.InternalContext, err = sjson.DeleteBytes(regFlow.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeySessionData)) if err != nil { - return s.handleRegistrationError(w, r, regFlow, &p, err) + return s.handleRegistrationError(r, regFlow, p, err) } if err := s.d.RegistrationFlowPersister().UpdateRegistrationFlow(ctx, regFlow); err != nil { - return s.handleRegistrationError(w, r, regFlow, &p, err) + return s.handleRegistrationError(r, regFlow, p, err) } return nil diff --git a/selfservice/strategy/webauthn/settings.go b/selfservice/strategy/webauthn/settings.go index 3626b3fbe61c..6e97c31b54f9 100644 --- a/selfservice/strategy/webauthn/settings.go +++ b/selfservice/strategy/webauthn/settings.go @@ -105,20 +105,20 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. var p updateSettingsFlowWithWebAuthnMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { - return ctxUpdate, s.continueSettingsFlow(w, r, ctxUpdate, &p) + return ctxUpdate, s.continueSettingsFlow(w, r, ctxUpdate, p) } else if err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } if err := s.decodeSettingsFlow(r, &p); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } if len(p.Register+p.Remove) > 0 { // This method has only two submit buttons p.Method = s.SettingsStrategyID() if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { - return nil, s.handleSettingsError(w, r, ctxUpdate, &p, err) + return nil, s.handleSettingsError(w, r, ctxUpdate, p, err) } } else { return nil, errors.WithStack(flow.ErrStrategyNotResponsible) @@ -126,8 +126,8 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. // This does not come from the payload! p.Flow = ctxUpdate.Flow.ID.String() - if err := s.continueSettingsFlow(w, r, ctxUpdate, &p); err != nil { - return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) + if err := s.continueSettingsFlow(w, r, ctxUpdate, p); err != nil { + return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } return ctxUpdate, nil @@ -148,7 +148,7 @@ func (s *Strategy) decodeSettingsFlow(r *http.Request, dest interface{}) error { func (s *Strategy) continueSettingsFlow( w http.ResponseWriter, r *http.Request, - ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithWebAuthnMethod, + ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithWebAuthnMethod, ) error { if len(p.Register+p.Remove) > 0 { if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { @@ -175,7 +175,7 @@ func (s *Strategy) continueSettingsFlow( return errors.New("ended up in unexpected state") } -func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithWebAuthnMethod) error { +func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithWebAuthnMethod) error { i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.IdentityID) if err != nil { return err @@ -231,7 +231,7 @@ func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Req return nil } -func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithWebAuthnMethod) error { +func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithWebAuthnMethod) error { webAuthnSession := gjson.GetBytes(ctxUpdate.Flow.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeySessionData)) if !webAuthnSession.IsObject() { return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected WebAuthN in internal context to be an object.")) @@ -388,7 +388,7 @@ func (s *Strategy) PopulateSettingsMethod(r *http.Request, id *identity.Identity return nil } -func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithWebAuthnMethod, err error) error { +func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithWebAuthnMethod, err error) error { // Do not pause flow if the flow type is an API flow as we can't save cookies in those flows. if e := new(settings.FlowNeedsReAuth); errors.As(err, &e) && ctxUpdate.Flow != nil && ctxUpdate.Flow.Type == flow.TypeBrowser { if err := s.d.ContinuityManager().Pause(r.Context(), w, r, settings.ContinuityKey(s.SettingsStrategyID()), settings.ContinuityOptions(p, ctxUpdate.GetSessionIdentity())...); err != nil { diff --git a/selfservice/strategy/webauthn/settings_test.go b/selfservice/strategy/webauthn/settings_test.go index 01ccdbf48578..dd75fc335204 100644 --- a/selfservice/strategy/webauthn/settings_test.go +++ b/selfservice/strategy/webauthn/settings_test.go @@ -332,6 +332,7 @@ func TestCompleteSettings(t *testing.T) { values.Set(node.WebAuthnRegister, string(settingsFixtureSuccessResponse)) values.Set(node.WebAuthnRegisterDisplayName, "foobar") body, res := testhelpers.SettingsMakeRequest(t, false, spa, f, browserClient, testhelpers.EncodeFormAsJSON(t, spa, values)) + require.Equal(t, http.StatusOK, res.StatusCode, body) if spa { assert.Contains(t, res.Request.URL.String(), publicTS.URL+settings.RouteSubmitFlow) @@ -343,7 +344,7 @@ func TestCompleteSettings(t *testing.T) { actual, err := reg.Persister().GetIdentityConfidential(context.Background(), id.ID) require.NoError(t, err) cred, ok := actual.GetCredentials(identity.CredentialsTypeWebAuthn) - assert.True(t, ok) + require.True(t, ok) assert.Len(t, gjson.GetBytes(cred.Config, "credentials").Array(), 1) actualFlow, err := reg.SettingsFlowPersister().GetSettingsFlow(context.Background(), uuid.FromStringOrNil(f.Id))