Skip to content

Commit

Permalink
fix: refactor internal API to prevent panics (#4028)
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik authored Aug 12, 2024
1 parent 4f4394c commit 81bc152
Show file tree
Hide file tree
Showing 20 changed files with 222 additions and 222 deletions.
1 change: 1 addition & 0 deletions internal/client-go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/settings/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ func TestHandler(t *testing.T) {
require.NoError(t, json.Unmarshal(body, &f))

actual, res := testhelpers.SettingsMakeRequest(t, false, true, &f, primaryUser, fmt.Sprintf(`{"method":"profile", "numby": 15, "csrf_token": "%s"}`, x.FakeCSRFToken))
assert.Equal(t, http.StatusOK, res.StatusCode)
require.Equal(t, http.StatusOK, res.StatusCode)
require.Len(t, primaryUser.Jar.Cookies(urlx.ParseOrPanic(publicTS.URL+login.RouteGetFlow)), 1)
require.Contains(t, fmt.Sprintf("%v", primaryUser.Jar.Cookies(urlx.ParseOrPanic(publicTS.URL))), "ory_kratos_session")
assert.Equal(t, "Your changes have been saved!", gjson.Get(actual, "ui.messages.0.text").String(), actual)
Expand Down
34 changes: 19 additions & 15 deletions selfservice/strategy/code/strategy_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,15 @@ func (s *Strategy) decodeVerification(r *http.Request) (*updateVerificationFlowW
}

// handleVerificationError is a convenience function for handling all types of errors that may occur (e.g. validation error).
func (s *Strategy) handleVerificationError(w http.ResponseWriter, r *http.Request, f *verification.Flow, body *updateVerificationFlowWithCodeMethod, err error) error {
func (s *Strategy) handleVerificationError(r *http.Request, f *verification.Flow, body *updateVerificationFlowWithCodeMethod, err error) error {
if f != nil {
f.UI.SetCSRF(s.deps.GenerateCSRFToken(r))
email := ""
if body != nil {
email = body.Email
}
f.UI.GetNodes().Upsert(
node.NewInputField("email", body.Email, node.CodeGroup, node.InputAttributeTypeEmail, node.WithRequiredInputAttribute).WithMetaLabel(text.NewInfoNodeInputEmail()),
node.NewInputField("email", email, node.CodeGroup, node.InputAttributeTypeEmail, node.WithRequiredInputAttribute).WithMetaLabel(text.NewInfoNodeInputEmail()),
)
}

Expand Down Expand Up @@ -137,17 +141,17 @@ func (s *Strategy) Verify(w http.ResponseWriter, r *http.Request, f *verificatio

body, err := s.decodeVerification(r)
if err != nil {
return s.handleVerificationError(w, r, nil, body, err)
return s.handleVerificationError(r, nil, body, err)
}

f.TransientPayload = body.TransientPayload

if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), string(body.getMethod()), s.deps); err != nil {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}

if err := f.Valid(); err != nil {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}

switch f.State {
Expand Down Expand Up @@ -197,28 +201,28 @@ func (s *Strategy) verificationHandleFormSubmission(w http.ResponseWriter, r *ht
return s.verificationUseCode(w, r, body.Code, f)
} else if len(body.Email) == 0 {
// If no code and no email was provided, fail with a validation error
return s.handleVerificationError(w, r, f, body, schema.NewRequiredError("#/email", "email"))
return s.handleVerificationError(r, f, body, schema.NewRequiredError("#/email", "email"))
}

if err := flow.EnsureCSRF(s.deps, r, f.Type, s.deps.Config().DisableAPIFlowEnforcement(r.Context()), s.deps.GenerateCSRFToken, body.CSRFToken); err != nil {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}

if err := s.deps.VerificationCodePersister().DeleteVerificationCodesOfFlow(r.Context(), f.ID); err != nil {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}

if err := s.deps.CodeSender().SendVerificationCode(r.Context(), f, identity.VerifiableAddressTypeEmail, body.Email); err != nil {
if !errors.Is(err, ErrUnknownAddress) {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}
// Continue execution
}

f.State = flow.StateEmailSent

if err := s.PopulateVerificationMethod(r, f); err != nil {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}

if body.Email != "" {
Expand All @@ -229,7 +233,7 @@ func (s *Strategy) verificationHandleFormSubmission(w http.ResponseWriter, r *ht
}

if err := s.deps.VerificationFlowPersister().UpdateVerificationFlow(r.Context(), f); err != nil {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}

return nil
Expand Down Expand Up @@ -305,13 +309,13 @@ func (s *Strategy) retryVerificationFlowWithMessage(w http.ResponseWriter, r *ht
f, err := verification.NewFlow(s.deps.Config(),
s.deps.Config().SelfServiceFlowVerificationRequestLifespan(r.Context()), s.deps.CSRFHandler().RegenerateToken(w, r), r, s, ft)
if err != nil {
return s.handleVerificationError(w, r, f, nil, err)
return s.handleVerificationError(r, f, nil, err)
}

f.UI.Messages.Add(message)

if err := s.deps.VerificationFlowPersister().CreateVerificationFlow(r.Context(), f); err != nil {
return s.handleVerificationError(w, r, f, nil, err)
return s.handleVerificationError(r, f, nil, err)
}

if x.IsJSONRequest(r) {
Expand All @@ -333,7 +337,7 @@ func (s *Strategy) retryVerificationFlowWithError(w http.ResponseWriter, r *http
f, err := verification.NewFlow(s.deps.Config(),
s.deps.Config().SelfServiceFlowVerificationRequestLifespan(r.Context()), s.deps.CSRFHandler().RegenerateToken(w, r), r, s, ft)
if err != nil {
return s.handleVerificationError(w, r, f, nil, err)
return s.handleVerificationError(r, f, nil, err)
}

var toReturn error
Expand All @@ -346,7 +350,7 @@ func (s *Strategy) retryVerificationFlowWithError(w http.ResponseWriter, r *http
}

if err := s.deps.VerificationFlowPersister().CreateVerificationFlow(r.Context(), f); err != nil {
return s.handleVerificationError(w, r, f, nil, err)
return s.handleVerificationError(r, f, nil, err)
}

if x.IsJSONRequest(r) {
Expand Down
16 changes: 8 additions & 8 deletions selfservice/strategy/idfirst/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ var (
ErrNoCredentialsFound = errors.New("no credentials found")
)

func (s *Strategy) handleLoginError(w http.ResponseWriter, r *http.Request, f *login.Flow, payload *updateLoginFlowWithIdentifierFirstMethod, err error) error {
func (s *Strategy) handleLoginError(r *http.Request, f *login.Flow, payload updateLoginFlowWithIdentifierFirstMethod, err error) error {
if f != nil {
f.UI.Nodes.SetValueAttribute("identifier", payload.Identifier)
if f.Type == flow.TypeBrowser {
Expand All @@ -52,12 +52,12 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
decoderx.HTTPDecoderSetValidatePayloads(true),
decoderx.MustHTTPRawJSONSchemaCompiler(loginSchema),
decoderx.HTTPDecoderJSONFollowsFormFormat()); err != nil {
return nil, s.handleLoginError(w, r, f, &p, err)
return nil, s.handleLoginError(r, f, p, err)
}
f.TransientPayload = p.TransientPayload

if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil {
return nil, s.handleLoginError(w, r, f, &p, err)
return nil, s.handleLoginError(r, f, p, err)
}

var opts []login.FormHydratorModifier
Expand All @@ -74,11 +74,11 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
// This will later be handled by `didPopulate`.
} else if err != nil {
// An error happened during lookup
return nil, s.handleLoginError(w, r, f, &p, err)
return nil, s.handleLoginError(r, f, p, err)
} else if !s.d.Config().SecurityAccountEnumerationMitigate(r.Context()) {
// Hydrate credentials
if err := s.d.PrivilegedIdentityPool().HydrateIdentityAssociations(r.Context(), identityHint, identity.ExpandCredentials); err != nil {
return nil, s.handleLoginError(w, r, f, &p, err)
return nil, s.handleLoginError(r, f, p, err)
}
}

Expand All @@ -102,7 +102,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
} else if errors.Is(err, ErrNoCredentialsFound) {
// This strategy is not responsible for this flow. We do not set didPopulate to true if that happens.
} else if err != nil {
return nil, s.handleLoginError(w, r, f, &p, err)
return nil, s.handleLoginError(r, f, p, err)
} else {
didPopulate = true
}
Expand All @@ -111,7 +111,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
// If no strategy populated, it means that the account (very likely) does not exist. We show a user not found error,
// but only if account enumeration mitigation is disabled. Otherwise, we proceed to render the rest of the form.
if !didPopulate && !s.d.Config().SecurityAccountEnumerationMitigate(r.Context()) {
return nil, s.handleLoginError(w, r, f, &p, errors.WithStack(schema.NewAccountNotFoundError()))
return nil, s.handleLoginError(r, f, p, errors.WithStack(schema.NewAccountNotFoundError()))
}

// We found credentials - hide the identifier.
Expand All @@ -134,7 +134,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,

f.Active = s.ID()
if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil {
return nil, s.handleLoginError(w, r, f, &p, err)
return nil, s.handleLoginError(r, f, p, err)
}

if x.IsJSONRequest(r) {
Expand Down
43 changes: 22 additions & 21 deletions selfservice/strategy/link/strategy_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,16 @@ func (s *Strategy) decodeVerification(r *http.Request) (*verificationSubmitPaylo
}

// handleVerificationError is a convenience function for handling all types of errors that may occur (e.g. validation error).
func (s *Strategy) handleVerificationError(w http.ResponseWriter, r *http.Request, f *verification.Flow, body *verificationSubmitPayload, err error) error {
func (s *Strategy) handleVerificationError(r *http.Request, f *verification.Flow, body *verificationSubmitPayload, err error) error {
if f != nil {
f.UI.SetCSRF(s.d.GenerateCSRFToken(r))
email := ""
if body != nil {
email = body.Email
}
f.UI.GetNodes().Upsert(
// v0.5: form.Field{Name: "email", Type: "email", Required: true, Value: body.Body.Email}
node.NewInputField("email", body.Email, node.LinkGroup, node.InputAttributeTypeEmail, node.WithRequiredInputAttribute).WithMetaLabel(text.NewInfoNodeInputEmail()),
node.NewInputField("email", email, node.LinkGroup, node.InputAttributeTypeEmail, node.WithRequiredInputAttribute).WithMetaLabel(text.NewInfoNodeInputEmail()),
)
}

Expand Down Expand Up @@ -132,56 +136,53 @@ func (s *Strategy) Verify(w http.ResponseWriter, r *http.Request, f *verificatio

body, err := s.decodeVerification(r)
if err != nil {
return s.handleVerificationError(w, r, nil, body, err)
return s.handleVerificationError(r, nil, body, err)
}
f.TransientPayload = body.TransientPayload

if len(body.Token) > 0 {
if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), s.VerificationStrategyID(), s.d); err != nil {
return s.handleVerificationError(w, r, nil, body, err)
return s.handleVerificationError(r, nil, body, err)
}

return s.verificationUseToken(w, r, body, f)
}

if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), body.Method, s.d); err != nil {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}

if err := f.Valid(); err != nil {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}

switch f.State {
case flow.StateChooseMethod:
fallthrough
case flow.StateEmailSent:
// Do nothing (continue with execution after this switch statement)
return s.verificationHandleFormSubmission(w, r, f)
case flow.StateChooseMethod, flow.StateEmailSent:
return s.verificationHandleFormSubmission(r, f)
case flow.StatePassedChallenge:
return s.retryVerificationFlowWithMessage(w, r, f.Type, text.NewErrorValidationVerificationRetrySuccess())
default:
return s.retryVerificationFlowWithMessage(w, r, f.Type, text.NewErrorValidationVerificationStateFailure())
}
}

func (s *Strategy) verificationHandleFormSubmission(w http.ResponseWriter, r *http.Request, f *verification.Flow) error {
func (s *Strategy) verificationHandleFormSubmission(r *http.Request, f *verification.Flow) error {
body, err := s.decodeVerification(r)
if err != nil {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}

if len(body.Email) == 0 {
return s.handleVerificationError(w, r, f, body, schema.NewRequiredError("#/email", "email"))
return s.handleVerificationError(r, f, body, schema.NewRequiredError("#/email", "email"))
}

if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, body.CSRFToken); err != nil {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}

if err := s.d.LinkSender().SendVerificationLink(r.Context(), f, identity.VerifiableAddressTypeEmail, body.Email); err != nil {
if !errors.Is(err, ErrUnknownAddress) {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}
// Continue execution
}
Expand All @@ -196,7 +197,7 @@ func (s *Strategy) verificationHandleFormSubmission(w http.ResponseWriter, r *ht
f.State = flow.StateEmailSent
f.UI.Messages.Set(text.NewVerificationEmailSent())
if err := s.d.VerificationFlowPersister().UpdateVerificationFlow(r.Context(), f); err != nil {
return s.handleVerificationError(w, r, f, body, err)
return s.handleVerificationError(r, f, body, err)
}

return nil
Expand Down Expand Up @@ -268,12 +269,12 @@ func (s *Strategy) retryVerificationFlowWithMessage(w http.ResponseWriter, r *ht
f, err := verification.NewFlow(s.d.Config(),
s.d.Config().SelfServiceFlowVerificationRequestLifespan(r.Context()), s.d.CSRFHandler().RegenerateToken(w, r), r, s, ft)
if err != nil {
return s.handleVerificationError(w, r, f, nil, err)
return s.handleVerificationError(r, f, nil, err)
}

f.UI.Messages.Add(message)
if err := s.d.VerificationFlowPersister().CreateVerificationFlow(r.Context(), f); err != nil {
return s.handleVerificationError(w, r, f, nil, err)
return s.handleVerificationError(r, f, nil, err)
}

if ft == flow.TypeBrowser {
Expand All @@ -292,7 +293,7 @@ func (s *Strategy) retryVerificationFlowWithError(w http.ResponseWriter, r *http
f, err := verification.NewFlow(s.d.Config(),
s.d.Config().SelfServiceFlowVerificationRequestLifespan(r.Context()), s.d.CSRFHandler().RegenerateToken(w, r), r, s, ft)
if err != nil {
return s.handleVerificationError(w, r, f, nil, err)
return s.handleVerificationError(r, f, nil, err)
}

if expired := new(flow.ExpiredError); errors.As(verErr, &expired) {
Expand All @@ -304,7 +305,7 @@ func (s *Strategy) retryVerificationFlowWithError(w http.ResponseWriter, r *http
}

if err := s.d.VerificationFlowPersister().CreateVerificationFlow(r.Context(), f); err != nil {
return s.handleVerificationError(w, r, f, nil, err)
return s.handleVerificationError(r, f, nil, err)
}

if ft == flow.TypeBrowser {
Expand Down
Loading

0 comments on commit 81bc152

Please sign in to comment.