diff --git a/Makefile b/Makefile index d1eb8fae5ef1..2759e18aae2e 100644 --- a/Makefile +++ b/Makefile @@ -83,7 +83,13 @@ test-short: .PHONY: test-coverage test-coverage: .bin/go-acc .bin/goveralls - go-acc -o coverage.out ./... -- -v -failfast -timeout=20m -tags sqlite + go-acc -o coverage.out ./... -- -v -failfast -timeout=20m -tags sqlite,json1 + +.PHONY: test-coverage-next +test-coverage-next: .bin/go-acc .bin/goveralls + go test -short -failfast -timeout=20m -tags sqlite,json1 -cover ./... --args test.gocoverdir="$$PWD/coverage" + go tool covdata percent -i=coverage + go tool covdata textfmt -i=./coverage -o coverage.new.out # Generates the SDK .PHONY: sdk diff --git a/coverage/.gitignore b/coverage/.gitignore new file mode 100644 index 000000000000..72e8ffc0db8a --- /dev/null +++ b/coverage/.gitignore @@ -0,0 +1 @@ +* diff --git a/driver/registry_default_registration.go b/driver/registry_default_registration.go index 0f6f7c6f05ea..89ed5e656c74 100644 --- a/driver/registry_default_registration.go +++ b/driver/registry_default_registration.go @@ -39,7 +39,7 @@ func (m *RegistryDefault) PostRegistrationPostPersistHooks(ctx context.Context, } if len(b) == initialHookCount { - // since we don't want merging hooks defined in a specific strategy and global hooks + // since we don't want merging hooks defined in a specific strategy and // global hooks are added only if no strategy specific hooks are defined for _, v := range m.getHooks(config.HookGlobal, m.Config().SelfServiceFlowRegistrationAfterHooks(ctx, config.HookGlobal)) { if hook, ok := v.(registration.PostHookPostPersistExecutor); ok { diff --git a/identity/credentials.go b/identity/credentials.go index 29283b29aa66..6ccbe867c89a 100644 --- a/identity/credentials.go +++ b/identity/credentials.go @@ -168,12 +168,6 @@ type Credentials struct { // Identifiers represents a list of unique identifiers this credential type matches. Identifiers []string `json:"identifiers" db:"-"` - // IdentifierAddressType represents the type of the identifiers (e.g. email, phone). - // This is used to determine the correct courier to send messages to. - // The value is set by the code extension schema and is not persisted. - // only applicable on the login, registration with `code` method. - IdentifierAddressType CredentialsIdentifierAddressType `json:"-" db:"-"` - // Config contains the concrete credential payload. This might contain the bcrypt-hashed password, the email // for passwordless authentication or access_token and refresh tokens from OpenID Connect flows. Config sqlxx.JSONRawMessage `json:"config,omitempty" db:"config"` diff --git a/identity/credentials_code.go b/identity/credentials_code.go index b6fc4a14b4fc..184479ae1700 100644 --- a/identity/credentials_code.go +++ b/identity/credentials_code.go @@ -14,7 +14,7 @@ const ( CodeAddressTypePhone CodeAddressType = AddressTypePhone ) -// CredentialsCode represents a one time login/registraiton code +// CredentialsCode represents a one time login/registration code // // swagger:model identityCredentialsCode type CredentialsCode struct { diff --git a/identity/extension_credentials.go b/identity/extension_credentials.go index 7885abf10bce..69fb53810fbc 100644 --- a/identity/extension_credentials.go +++ b/identity/extension_credentials.go @@ -41,7 +41,6 @@ func (r *SchemaExtensionCredentials) setIdentifier(ct CredentialsType, value int r.v[ct] = stringslice.Unique(append(r.v[ct], strings.ToLower(fmt.Sprintf("%s", value)))) cred.Identifiers = r.v[ct] - cred.IdentifierAddressType = addressType r.i.SetCredentials(ct, *cred) } @@ -64,7 +63,7 @@ func (r *SchemaExtensionCredentials) Run(ctx jsonschema.ValidationContext, s sch return ctx.Error("format", "%q is not a valid %q", value, s.Credentials.Code.Via) } - r.setIdentifier(CredentialsTypeCodeAuth, value, CredentialsIdentifierAddressType(AddressTypeEmail)) + r.setIdentifier(CredentialsTypeCodeAuth, value, AddressTypeEmail) // case f.AddCase(AddressTypePhone): // if !jsonschema.Formats["tel"](value) { // return ctx.Error("format", "%q is not a valid %q", value, s.Credentials.Code.Via) diff --git a/identity/extension_credentials_test.go b/identity/extension_credentials_test.go index cf580c6a0c10..95cd9d000c6a 100644 --- a/identity/extension_credentials_test.go +++ b/identity/extension_credentials_test.go @@ -72,6 +72,21 @@ func TestSchemaExtensionCredentials(t *testing.T) { }, ct: identity.CredentialsTypeWebAuthn, }, + { + doc: `{"email":"foo@ory.sh"}`, + schema: "file://./stub/extension/credentials/code.schema.json", + expect: []string{"foo@ory.sh"}, + ct: identity.CredentialsTypeCodeAuth, + }, + { + doc: `{"email":"FOO@ory.sh"}`, + schema: "file://./stub/extension/credentials/code.schema.json", + expect: []string{"foo@ory.sh"}, + existing: &identity.Credentials{ + Identifiers: []string{"not-foo@ory.sh"}, + }, + ct: identity.CredentialsTypeCodeAuth, + }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { c := jsonschema.NewCompiler() diff --git a/identity/stub/extension/credentials/code.schema.json b/identity/stub/extension/credentials/code.schema.json new file mode 100644 index 000000000000..bef244bc9ae5 --- /dev/null +++ b/identity/stub/extension/credentials/code.schema.json @@ -0,0 +1,20 @@ +{ + "type": "object", + "properties": { + "email": { + "type": "string", + "format": "email", + "ory.sh/kratos": { + "credentials": { + "password": { + "identifier": true + }, + "code": { + "identifier": true, + "via": "email" + } + } + } + } + } +} diff --git a/persistence/sql/persister_code.go b/persistence/sql/persister_code.go new file mode 100644 index 000000000000..3b8103a36361 --- /dev/null +++ b/persistence/sql/persister_code.go @@ -0,0 +1,123 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package sql + +import ( + "context" + "crypto/subtle" + "fmt" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + + "github.com/ory/kratos/selfservice/strategy/code" + "github.com/ory/x/sqlcon" +) + +type oneTimeCodeProvider interface { + GetID() uuid.UUID + Validate() error + TableName(ctx context.Context) string + GetHMACCode() string +} + +type codeOptions struct { + IdentityID *uuid.UUID +} + +type codeOption func(o *codeOptions) + +func withCheckIdentityID(id uuid.UUID) codeOption { + return func(o *codeOptions) { + o.IdentityID = &id + } +} + +func useOneTimeCode[P any, U interface { + *P + oneTimeCodeProvider +}](ctx context.Context, p *Persister, flowID uuid.UUID, userProvidedCode string, flowTableName string, foreignKeyName string, opts ...codeOption) (U, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.useOneTimeCode") + defer span.End() + + o := new(codeOptions) + for _, opt := range opts { + opt(o) + } + + var target U + nid := p.NetworkID(ctx) + if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + //#nosec G201 -- TableName is static + if err := tx.RawQuery(fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), flowID, nid).Exec(); err != nil { + return err + } + + var submitCount int + // Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on. + //#nosec G201 -- TableName is static + if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), flowID, nid).First(&submitCount)); err != nil { + if errors.Is(err, sqlcon.ErrNoRows) { + // Return no error, as that would roll back the transaction + return nil + } + return err + } + + // This check prevents parallel brute force attacks by checking the submit count inside this database + // transaction. If the flow has been submitted more than 5 times, the transaction is aborted (regardless of + // whether the code was correct or not) and we thus give no indication whether the supplied code was correct or + // not. For more explanation see [this comment](https://github.com/ory/kratos/pull/2645#discussion_r984732899). + if submitCount > 5 { + return errors.WithStack(code.ErrCodeSubmittedTooOften) + } + + var codes []U + codesQuery := tx.Where(fmt.Sprintf("nid = ? AND %s = ?", foreignKeyName), nid, flowID) + if o.IdentityID != nil { + codesQuery = codesQuery.Where("identity_id = ?", *o.IdentityID) + } + + if err := sqlcon.HandleError(codesQuery.All(&codes)); err != nil { + if errors.Is(err, sqlcon.ErrNoRows) { + // Return no error, as that would roll back the transaction and reset the submit count. + return nil + } + + return err + } + + secrets: + for _, secret := range p.r.Config().SecretsSession(ctx) { + suppliedCode := []byte(p.hmacValueWithSecret(ctx, userProvidedCode, secret)) + for i := range codes { + c := codes[i] + if subtle.ConstantTimeCompare([]byte(c.GetHMACCode()), suppliedCode) == 0 { + // Not the supplied code + continue + } + target = c + break secrets + } + } + + if target.Validate() != nil { + // Return no error, as that would roll back the transaction + return nil + } + + //#nosec G201 -- TableName is static + return tx.RawQuery(fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", target.TableName(ctx)), time.Now().UTC(), target.GetID(), nid).Exec() + }); err != nil { + return nil, sqlcon.HandleError(err) + } + + if err := target.Validate(); err != nil { + return nil, err + } + + return target, nil +} diff --git a/persistence/sql/persister_login.go b/persistence/sql/persister_login.go index d34a832167f5..ec1da55babbb 100644 --- a/persistence/sql/persister_login.go +++ b/persistence/sql/persister_login.go @@ -5,21 +5,16 @@ package sql import ( "context" - "crypto/subtle" "fmt" "time" "github.com/gobuffalo/pop/v6" - "github.com/pkg/errors" - "github.com/gofrs/uuid" "github.com/ory/x/sqlcon" "github.com/ory/kratos/persistence/sql/update" - "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/login" - "github.com/ory/kratos/selfservice/strategy/code" ) var _ login.FlowPersister = new(Persister) @@ -88,123 +83,3 @@ func (p *Persister) DeleteExpiredLoginFlows(ctx context.Context, expiresAt time. } return nil } - -func (p *Persister) CreateLoginCode(ctx context.Context, codeParams *code.CreateLoginCodeParams) (*code.LoginCode, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLoginCode") - defer span.End() - - now := time.Now().UTC() - loginCode := &code.LoginCode{ - IdentityID: codeParams.IdentityID, - Address: codeParams.Address, - AddressType: codeParams.AddressType, - CodeHMAC: p.hmacValue(ctx, codeParams.RawCode), - IssuedAt: now, - ExpiresAt: now.UTC().Add(p.r.Config().SelfServiceCodeMethodLifespan(ctx)), - FlowID: codeParams.FlowID, - NID: p.NetworkID(ctx), - ID: uuid.Nil, - } - - if err := p.GetConnection(ctx).Create(loginCode); err != nil { - return nil, sqlcon.HandleError(err) - } - return loginCode, nil -} - -func (p *Persister) UseLoginCode(ctx context.Context, flowID uuid.UUID, identityID uuid.UUID, codeVal string) (*code.LoginCode, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseLoginCode") - defer span.End() - - var loginCode *code.LoginCode - - nid := p.NetworkID(ctx) - flowTableName := new(login.Flow).TableName(ctx) - - if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) { - //#nosec G201 -- TableName is static - if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), flowID, nid).Exec()); err != nil { - return err - } - - var submitCount int - // Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on. - //#nosec G201 -- TableName is static - if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), flowID, nid).First(&submitCount)); err != nil { - if errors.Is(err, sqlcon.ErrNoRows) { - // Return no error, as that would roll back the transaction - return nil - } - return err - } - - if submitCount > 5 { - return errors.WithStack(code.ErrCodeSubmittedTooOften) - } - - var loginCodes []code.LoginCode - if err = sqlcon.HandleError(tx.Where("nid = ? AND selfservice_login_flow_id = ? AND identity_id = ?", nid, flowID, identityID).All(&loginCodes)); err != nil { - if errors.Is(err, sqlcon.ErrNoRows) { - return err - } - return nil - } - - secrets: - for _, secret := range p.r.Config().SecretsSession(ctx) { - suppliedCode := []byte(p.hmacValueWithSecret(ctx, codeVal, secret)) - for i := range loginCodes { - code := loginCodes[i] - if subtle.ConstantTimeCompare([]byte(code.CodeHMAC), suppliedCode) == 0 { - // Not the supplied code - continue - } - loginCode = &code - break secrets - } - } - - if loginCode == nil || !loginCode.IsValid() { - // Return no error, as that would roll back the transaction - return nil - } - - //#nosec G201 -- TableName is static - return sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", loginCode.TableName(ctx)), time.Now().UTC(), loginCode.ID, nid).Exec()) - })); err != nil { - return nil, err - } - - if loginCode == nil { - return nil, errors.WithStack(code.ErrCodeNotFound) - } - - if loginCode.IsExpired() { - return nil, errors.WithStack(flow.NewFlowExpiredError(loginCode.ExpiresAt)) - } - - if loginCode.WasUsed() { - return nil, errors.WithStack(code.ErrCodeAlreadyUsed) - } - - return loginCode, nil -} - -func (p *Persister) DeleteLoginCodesOfFlow(ctx context.Context, flowID uuid.UUID) error { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteLoginCodesOfFlow") - defer span.End() - - //#nosec G201 -- TableName is static - return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE selfservice_login_flow_id = ? AND nid = ?", new(code.LoginCode).TableName(ctx)), flowID, p.NetworkID(ctx)).Exec() -} - -func (p *Persister) GetUsedLoginCode(ctx context.Context, flowID uuid.UUID) (*code.LoginCode, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetUsedLoginCode") - defer span.End() - - var loginCode code.LoginCode - if err := p.Connection(ctx).RawQuery(fmt.Sprintf("SELECT * FROM %s WHERE selfservice_login_flow_id = ? AND nid = ? AND used_at IS NOT NULL", new(code.LoginCode).TableName(ctx)), flowID, p.NetworkID(ctx)).First(&loginCode); err != nil { - return nil, sqlcon.HandleError(err) - } - return &loginCode, nil -} diff --git a/persistence/sql/persister_login_code.go b/persistence/sql/persister_login_code.go new file mode 100644 index 000000000000..3d5dd027826d --- /dev/null +++ b/persistence/sql/persister_login_code.go @@ -0,0 +1,69 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package sql + +import ( + "context" + "time" + + "github.com/gofrs/uuid" + + "github.com/ory/kratos/selfservice/flow/login" + "github.com/ory/kratos/selfservice/strategy/code" + "github.com/ory/x/sqlcon" +) + +func (p *Persister) CreateLoginCode(ctx context.Context, params *code.CreateLoginCodeParams) (*code.LoginCode, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLoginCode") + defer span.End() + + now := time.Now().UTC() + loginCode := &code.LoginCode{ + IdentityID: params.IdentityID, + Address: params.Address, + AddressType: params.AddressType, + CodeHMAC: p.hmacValue(ctx, params.RawCode), + IssuedAt: now, + ExpiresAt: now.UTC().Add(p.r.Config().SelfServiceCodeMethodLifespan(ctx)), + FlowID: params.FlowID, + NID: p.NetworkID(ctx), + ID: uuid.Nil, + } + + if err := p.GetConnection(ctx).Create(loginCode); err != nil { + return nil, sqlcon.HandleError(err) + } + + return loginCode, nil +} + +func (p *Persister) UseLoginCode(ctx context.Context, flowID uuid.UUID, identityID uuid.UUID, userProvidedCode string) (*code.LoginCode, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseLoginCode") + defer span.End() + + codeRow, err := useOneTimeCode[code.LoginCode, *code.LoginCode](ctx, p, flowID, userProvidedCode, new(login.Flow).TableName(ctx), "selfservice_login_flow_id", withCheckIdentityID(identityID)) + if err != nil { + return nil, err + } + + return codeRow, nil +} + +func (p *Persister) GetUsedLoginCode(ctx context.Context, flowID uuid.UUID) (*code.LoginCode, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetUsedLoginCode") + defer span.End() + + var loginCode code.LoginCode + if err := p.Connection(ctx).Where("selfservice_login_flow_id = ? AND nid = ? AND used_at IS NOT NULL", flowID, p.NetworkID(ctx)).First(&loginCode); err != nil { + return nil, sqlcon.HandleError(err) + } + return &loginCode, nil +} + +func (p *Persister) DeleteLoginCodesOfFlow(ctx context.Context, flowID uuid.UUID) error { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteLoginCodesOfFlow") + defer span.End() + + return p.GetConnection(ctx).Where("selfservice_login_flow_id = ? AND nid = ?", flowID, p.NetworkID(ctx)).Delete(&code.LoginCode{}) +} diff --git a/persistence/sql/persister_recovery.go b/persistence/sql/persister_recovery.go index 34539832d254..8ac81cd009a5 100644 --- a/persistence/sql/persister_recovery.go +++ b/persistence/sql/persister_recovery.go @@ -5,7 +5,6 @@ package sql import ( "context" - "crypto/subtle" "fmt" "time" @@ -16,9 +15,7 @@ import ( "github.com/ory/kratos/identity" "github.com/ory/kratos/persistence/sql/update" - "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/recovery" - "github.com/ory/kratos/selfservice/strategy/code" "github.com/ory/kratos/selfservice/strategy/link" "github.com/ory/x/sqlcon" ) @@ -137,144 +134,3 @@ func (p *Persister) DeleteExpiredRecoveryFlows(ctx context.Context, expiresAt ti } return nil } - -func (p *Persister) CreateRecoveryCode(ctx context.Context, dto *code.CreateRecoveryCodeParams) (*code.RecoveryCode, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateRecoveryCode") - defer span.End() - - now := time.Now() - - recoveryCode := &code.RecoveryCode{ - ID: uuid.Nil, - CodeHMAC: p.hmacValue(ctx, dto.RawCode), - ExpiresAt: now.UTC().Add(dto.ExpiresIn), - IssuedAt: now, - CodeType: dto.CodeType, - FlowID: dto.FlowID, - NID: p.NetworkID(ctx), - IdentityID: dto.IdentityID, - } - - if dto.RecoveryAddress != nil { - recoveryCode.RecoveryAddress = dto.RecoveryAddress - recoveryCode.RecoveryAddressID = uuid.NullUUID{ - UUID: dto.RecoveryAddress.ID, - Valid: true, - } - } - - // This should not create the request eagerly because otherwise we might accidentally create an address that isn't - // supposed to be in the database. - if err := p.GetConnection(ctx).Create(recoveryCode); err != nil { - return nil, err - } - - return recoveryCode, nil -} - -// UseRecoveryCode attempts to "use" the supplied code in the flow -// -// If the supplied code matched a code from the flow, no error is returned -// If an invalid code was submitted with this flow more than 5 times, an error is returned -// TODO: Extract the business logic to a new service/manager (https://github.com/ory/kratos/issues/2785) -func (p *Persister) UseRecoveryCode(ctx context.Context, fID uuid.UUID, codeVal string) (*code.RecoveryCode, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRecoveryCode") - defer span.End() - - var recoveryCode *code.RecoveryCode - - nid := p.NetworkID(ctx) - - flowTableName := new(recovery.Flow).TableName(ctx) - - if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) { - //#nosec G201 -- TableName is static - if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), fID, nid).Exec()); err != nil { - return err - } - - var submitCount int - // Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on. - //#nosec G201 -- TableName is static - if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), fID, nid).First(&submitCount)); err != nil { - if errors.Is(err, sqlcon.ErrNoRows) { - // Return no error, as that would roll back the transaction - return nil - } - - return err - } - - // This check prevents parallel brute force attacks to generate the recovery code - // by checking the submit count inside this database transaction. - // If the flow has been submitted more than 5 times, the transaction is aborted (regardless of whether the code was correct or not) - // and we thus give no indication whether the supplied code was correct or not. See also https://github.com/ory/kratos/pull/2645#discussion_r984732899 - if submitCount > 5 { - return errors.WithStack(code.ErrCodeSubmittedTooOften) - } - - var recoveryCodes []code.RecoveryCode - if err = sqlcon.HandleError(tx.Where("nid = ? AND selfservice_recovery_flow_id = ?", nid, fID).All(&recoveryCodes)); err != nil { - if errors.Is(err, sqlcon.ErrNoRows) { - // Return no error, as that would roll back the transaction - return nil - } - - return err - } - - secrets: - for _, secret := range p.r.Config().SecretsSession(ctx) { - suppliedCode := []byte(p.hmacValueWithSecret(ctx, codeVal, secret)) - for i := range recoveryCodes { - code := recoveryCodes[i] - if subtle.ConstantTimeCompare([]byte(code.CodeHMAC), suppliedCode) == 0 { - // Not the supplied code - continue - } - recoveryCode = &code - break secrets - } - } - - if recoveryCode == nil || !recoveryCode.IsValid() { - // Return no error, as that would roll back the transaction - return nil - } - - var ra identity.RecoveryAddress - if err := tx.Where("id = ? AND nid = ?", recoveryCode.RecoveryAddressID, nid).First(&ra); err != nil { - if err = sqlcon.HandleError(err); !errors.Is(err, sqlcon.ErrNoRows) { - return err - } - } - recoveryCode.RecoveryAddress = &ra - - //#nosec G201 -- TableName is static - return sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", recoveryCode.TableName(ctx)), time.Now().UTC(), recoveryCode.ID, nid).Exec()) - })); err != nil { - return nil, err - } - - if recoveryCode == nil { - return nil, errors.WithStack(code.ErrCodeNotFound) - } - - if recoveryCode.IsExpired() { - return nil, errors.WithStack(flow.NewFlowExpiredError(recoveryCode.ExpiresAt)) - } - - if recoveryCode.WasUsed() { - return nil, errors.WithStack(code.ErrCodeAlreadyUsed) - } - - return recoveryCode, nil -} - -func (p *Persister) DeleteRecoveryCodesOfFlow(ctx context.Context, fID uuid.UUID) error { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRecoveryCodesOfFlow") - defer span.End() - - //#nosec G201 -- TableName is static - return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE selfservice_recovery_flow_id = ? AND nid = ?", new(code.RecoveryCode).TableName(ctx)), fID, p.NetworkID(ctx)).Exec() -} diff --git a/persistence/sql/persister_recovery_code.go b/persistence/sql/persister_recovery_code.go new file mode 100644 index 000000000000..725b9578a205 --- /dev/null +++ b/persistence/sql/persister_recovery_code.go @@ -0,0 +1,84 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package sql + +import ( + "context" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + + "github.com/ory/kratos/identity" + "github.com/ory/kratos/selfservice/flow/recovery" + "github.com/ory/kratos/selfservice/strategy/code" + "github.com/ory/x/sqlcon" +) + +func (p *Persister) CreateRecoveryCode(ctx context.Context, params *code.CreateRecoveryCodeParams) (*code.RecoveryCode, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateRecoveryCode") + defer span.End() + + now := time.Now() + recoveryCode := &code.RecoveryCode{ + ID: uuid.Nil, + CodeHMAC: p.hmacValue(ctx, params.RawCode), + ExpiresAt: now.UTC().Add(params.ExpiresIn), + IssuedAt: now, + CodeType: params.CodeType, + FlowID: params.FlowID, + NID: p.NetworkID(ctx), + IdentityID: params.IdentityID, + } + + if params.RecoveryAddress != nil { + recoveryCode.RecoveryAddress = params.RecoveryAddress + recoveryCode.RecoveryAddressID = uuid.NullUUID{ + UUID: params.RecoveryAddress.ID, + Valid: true, + } + } + + // This should not create the request eagerly because otherwise we might accidentally create an address that isn't + // supposed to be in the database. + if err := p.GetConnection(ctx).Create(recoveryCode); err != nil { + return nil, err + } + + return recoveryCode, nil +} + +// UseRecoveryCode attempts to "use" the supplied code in the flow +// +// If the supplied code matched a code from the flow, no error is returned +// If an invalid code was submitted with this flow more than 5 times, an error is returned +func (p *Persister) UseRecoveryCode(ctx context.Context, flowID uuid.UUID, userProvidedCode string) (*code.RecoveryCode, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRecoveryCode") + defer span.End() + + codeRow, err := useOneTimeCode[code.RecoveryCode, *code.RecoveryCode](ctx, p, flowID, userProvidedCode, new(recovery.Flow).TableName(ctx), "selfservice_recovery_flow_id") + if err != nil { + return nil, err + } + + var ra identity.RecoveryAddress + if err := sqlcon.HandleError(p.GetConnection(ctx).Where("id = ? AND nid = ?", codeRow.RecoveryAddressID, p.NetworkID(ctx)).First(&ra)); err != nil { + if errors.Is(err, sqlcon.ErrNoRows) { + // This is ok, it can happen when an administrator initiates account recovery. This works even if the + // user has no recovery address! + } else { + return nil, err + } + } + codeRow.RecoveryAddress = &ra + + return codeRow, nil +} + +func (p *Persister) DeleteRecoveryCodesOfFlow(ctx context.Context, flowID uuid.UUID) error { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRecoveryCodesOfFlow") + defer span.End() + + return p.GetConnection(ctx).Where("selfservice_recovery_flow_id = ? AND nid = ?", flowID, p.NetworkID(ctx)).Delete(&code.RecoveryCode{}) +} diff --git a/persistence/sql/persister_registration.go b/persistence/sql/persister_registration.go index 7bd665fcedf0..fe7e25ceeac3 100644 --- a/persistence/sql/persister_registration.go +++ b/persistence/sql/persister_registration.go @@ -5,21 +5,15 @@ package sql import ( "context" - "crypto/subtle" "fmt" "time" - "github.com/bxcodec/faker/v3/support/slice" - "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" - "github.com/pkg/errors" "github.com/ory/x/sqlcon" "github.com/ory/kratos/persistence/sql/update" - "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/registration" - "github.com/ory/kratos/selfservice/strategy/code" ) func (p *Persister) CreateRegistrationFlow(ctx context.Context, r *registration.Flow) error { @@ -70,133 +64,3 @@ func (p *Persister) DeleteExpiredRegistrationFlows(ctx context.Context, expiresA } return nil } - -func (p *Persister) CreateRegistrationCode(ctx context.Context, codeParams *code.CreateRegistrationCodeParams) (*code.RegistrationCode, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateRegistrationCode") - defer span.End() - - now := time.Now() - - registrationCode := &code.RegistrationCode{ - Address: codeParams.Address, - AddressType: codeParams.AddressType, - CodeHMAC: p.hmacValue(ctx, codeParams.RawCode), - IssuedAt: now, - ExpiresAt: now.UTC().Add(p.r.Config().SelfServiceCodeMethodLifespan(ctx)), - FlowID: codeParams.FlowID, - NID: p.NetworkID(ctx), - ID: uuid.Nil, - } - - if err := p.GetConnection(ctx).Create(registrationCode); err != nil { - return nil, sqlcon.HandleError(err) - } - return registrationCode, nil -} - -func (p *Persister) UseRegistrationCode(ctx context.Context, flowID uuid.UUID, rawCode string, addresses ...string) (*code.RegistrationCode, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRegistrationCode") - defer span.End() - - nid := p.NetworkID(ctx) - - flowTableName := new(registration.Flow).TableName(ctx) - - var registrationCode *code.RegistrationCode - if err := sqlcon.HandleError(p.GetConnection(ctx).Transaction(func(tx *pop.Connection) error { - if err := tx.RawQuery(fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), flowID, nid).Exec(); err != nil { - return err - } - - var submitCount int - // Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on. - //#nosec G201 -- TableName is static - if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), flowID, nid).First(&submitCount)); err != nil { - if errors.Is(err, sqlcon.ErrNoRows) { - // Return no error, as that would roll back the transaction - return nil - } - return err - } - - // This check prevents parallel brute force attacks to generate the recovery code - // by checking the submit count inside this database transaction. - // If the flow has been submitted more than 5 times, the transaction is aborted (regardless of whether the code was correct or not) - // and we thus give no indication whether the supplied code was correct or not. See also https://github.com/ory/kratos/pull/2645#discussion_r984732899 - if submitCount > 5 { - return errors.WithStack(code.ErrCodeSubmittedTooOften) - } - - var registrationCodes []code.RegistrationCode - if err := sqlcon.HandleError(tx.Where("nid = ? AND selfservice_registration_flow_id = ?", nid, flowID).All(®istrationCodes)); err != nil { - if errors.Is(err, sqlcon.ErrNoRows) { - // Return no error, as that would roll back the transaction - return nil - } - - return err - } - - secrets: - for _, secret := range p.r.Config().SecretsSession(ctx) { - suppliedCode := []byte(p.hmacValueWithSecret(ctx, rawCode, secret)) - for i := range registrationCodes { - code := registrationCodes[i] - if subtle.ConstantTimeCompare([]byte(code.CodeHMAC), suppliedCode) == 0 { - // Not the supplied code - continue - } - registrationCode = &code - break secrets - } - } - - if registrationCode == nil || !registrationCode.IsValid() { - // Return no error, as that would roll back the transaction - return nil - } - - //#nosec G201 -- TableName is static - return sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", registrationCode.TableName(ctx)), time.Now().UTC(), registrationCode.ID, nid).Exec()) - })); err != nil { - return nil, err - } - - if registrationCode == nil { - return nil, errors.WithStack(code.ErrCodeNotFound) - } - - if registrationCode.IsExpired() { - return nil, errors.WithStack(flow.NewFlowExpiredError(registrationCode.ExpiresAt)) - } - - if registrationCode.WasUsed() { - return nil, errors.WithStack(code.ErrCodeAlreadyUsed) - } - - // ensure that the identifiers extracted from the traits are contained in the registration code - if !slice.Contains(addresses, registrationCode.Address) { - return nil, errors.WithStack(code.ErrCodeNotFound) - } - - return registrationCode, nil -} - -func (p *Persister) DeleteRegistrationCodesOfFlow(ctx context.Context, flowID uuid.UUID) error { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRegistrationCodesOfFlow") - defer span.End() - - //#nosec G201 -- TableName is static - return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE selfservice_registration_flow_id = ? AND nid = ?", new(code.RegistrationCode).TableName(ctx)), flowID, p.NetworkID(ctx)).Exec() -} - -func (p *Persister) GetUsedRegistrationCode(ctx context.Context, flowID uuid.UUID) (*code.RegistrationCode, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetUsedRegistrationCode") - defer span.End() - - var registrationCode code.RegistrationCode - if err := p.Connection(ctx).RawQuery(fmt.Sprintf("SELECT * FROM %s WHERE selfservice_registration_flow_id = ? AND used_at IS NOT NULL AND nid = ?", new(code.RegistrationCode).TableName(ctx)), flowID, p.NetworkID(ctx)).First(®istrationCode); err != nil { - return nil, sqlcon.HandleError(err) - } - return ®istrationCode, nil -} diff --git a/persistence/sql/persister_registration_code.go b/persistence/sql/persister_registration_code.go new file mode 100644 index 000000000000..5c9ac909838c --- /dev/null +++ b/persistence/sql/persister_registration_code.go @@ -0,0 +1,76 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package sql + +import ( + "context" + "time" + + "github.com/bxcodec/faker/v3/support/slice" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + + "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/selfservice/strategy/code" + "github.com/ory/x/sqlcon" +) + +func (p *Persister) CreateRegistrationCode(ctx context.Context, params *code.CreateRegistrationCodeParams) (*code.RegistrationCode, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateRegistrationCode") + defer span.End() + + now := time.Now().UTC() + registrationCode := &code.RegistrationCode{ + Address: params.Address, + AddressType: params.AddressType, + CodeHMAC: p.hmacValue(ctx, params.RawCode), + IssuedAt: now, + ExpiresAt: now.UTC().Add(p.r.Config().SelfServiceCodeMethodLifespan(ctx)), + FlowID: params.FlowID, + NID: p.NetworkID(ctx), + ID: uuid.Nil, + } + + if err := p.GetConnection(ctx).Create(registrationCode); err != nil { + return nil, sqlcon.HandleError(err) + } + + return registrationCode, nil +} + +func (p *Persister) UseRegistrationCode(ctx context.Context, flowID uuid.UUID, userProvidedCode string, addresses ...string) (*code.RegistrationCode, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRegistrationCode") + defer span.End() + + codeRow, err := useOneTimeCode[code.RegistrationCode, *code.RegistrationCode](ctx, p, flowID, userProvidedCode, new(registration.Flow).TableName(ctx), "selfservice_registration_flow_id") + if err != nil { + return nil, err + } + + // ensure that the identifiers extracted from the traits are contained in the registration code + if !slice.Contains(addresses, codeRow.Address) { + return nil, errors.WithStack(code.ErrCodeNotFound) + } + + return codeRow, nil +} + +func (p *Persister) GetUsedRegistrationCode(ctx context.Context, flowID uuid.UUID) (*code.RegistrationCode, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetUsedRegistrationCode") + defer span.End() + + var registrationCode code.RegistrationCode + if err := p.Connection(ctx).Where("selfservice_registration_flow_id = ? AND used_at IS NOT NULL AND nid = ?", flowID, p.NetworkID(ctx)).First(®istrationCode); err != nil { + return nil, sqlcon.HandleError(err) + } + + return ®istrationCode, nil +} + +func (p *Persister) DeleteRegistrationCodesOfFlow(ctx context.Context, flowID uuid.UUID) error { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRegistrationCodesOfFlow") + defer span.End() + + return p.GetConnection(ctx).Where("selfservice_registration_flow_id = ? AND nid = ?", flowID, p.NetworkID(ctx)).Delete(&code.RegistrationCode{}) +} diff --git a/persistence/sql/persister_verification.go b/persistence/sql/persister_verification.go index 9892f41e9fc0..b2f19f94726b 100644 --- a/persistence/sql/persister_verification.go +++ b/persistence/sql/persister_verification.go @@ -5,13 +5,11 @@ package sql import ( "context" - "crypto/subtle" "fmt" "time" "github.com/pkg/errors" - "github.com/ory/herodot" "github.com/ory/kratos/identity" "github.com/ory/kratos/persistence/sql/update" @@ -21,7 +19,6 @@ import ( "github.com/ory/x/sqlcon" "github.com/ory/kratos/selfservice/flow/verification" - "github.com/ory/kratos/selfservice/strategy/code" "github.com/ory/kratos/selfservice/strategy/link" ) @@ -137,154 +134,3 @@ func (p *Persister) DeleteExpiredVerificationFlows(ctx context.Context, expiresA } return nil } -func (p *Persister) UseVerificationCode(ctx context.Context, fID uuid.UUID, codeVal string) (*code.VerificationCode, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseVerificationCode") - defer span.End() - - var verificationCode *code.VerificationCode - - nid := p.NetworkID(ctx) - - flowTableName := new(verification.Flow).TableName(ctx) - - if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) { - - if err := sqlcon.HandleError( - tx.RawQuery( - //#nosec G201 -- TableName is static - fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), - fID, - nid, - ).Exec(), - ); err != nil { - return err - } - - var submitCount int - // Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on. - if err := sqlcon.HandleError( - tx.RawQuery( - //#nosec G201 -- TableName is static - fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), - fID, - nid, - ).First(&submitCount), - ); err != nil { - if errors.Is(err, sqlcon.ErrNoRows) { - // Return no error, as that would roll back the transaction - return nil - } - - return err - } - // This check prevents parallel brute force attacks to generate the verification code - // by checking the submit count inside this database transaction. - // If the flow has been submitted more than 5 times, the transaction is aborted (regardless of whether the code was correct or not) - // and we thus give no indication whether the supplied code was correct or not. See also https://github.com/ory/kratos/pull/2645#discussion_r984732899 - if submitCount > 5 { - return errors.WithStack(code.ErrCodeSubmittedTooOften) - } - - var verificationCodes []code.VerificationCode - if err = sqlcon.HandleError( - tx.Where("nid = ? AND selfservice_verification_flow_id = ?", nid, fID). - All(&verificationCodes), - ); err != nil { - if errors.Is(err, sqlcon.ErrNoRows) { - // Return no error, as that would roll back the transaction - return nil - } - - return err - } - - secrets: - for _, secret := range p.r.Config().SecretsSession(ctx) { - suppliedCode := []byte(p.hmacValueWithSecret(ctx, codeVal, secret)) - for i := range verificationCodes { - code := verificationCodes[i] - if subtle.ConstantTimeCompare([]byte(code.CodeHMAC), suppliedCode) == 0 { - // Not the supplied code - continue - } - verificationCode = &code - break secrets - } - } - - if verificationCode == nil || verificationCode.Validate() != nil { - // Return no error, as that would roll back the transaction - return nil - } - - var va identity.VerifiableAddress - if err := tx.Where("id = ? AND nid = ?", verificationCode.VerifiableAddressID, nid).First(&va); err != nil { - return sqlcon.HandleError(err) - } - - verificationCode.VerifiableAddress = &va - - //#nosec G201 -- TableName is static - return tx. - RawQuery( - fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", verificationCode.TableName(ctx)), - time.Now().UTC(), - verificationCode.ID, - nid, - ).Exec() - })); err != nil { - return nil, err - } - - if verificationCode == nil { - return nil, errors.WithStack(code.ErrCodeNotFound) - } - - return verificationCode, nil -} - -func (p *Persister) CreateVerificationCode(ctx context.Context, c *code.CreateVerificationCodeParams) (*code.VerificationCode, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateVerificationCode") - defer span.End() - - now := time.Now().UTC() - - verificationCode := &code.VerificationCode{ - ID: uuid.Nil, - CodeHMAC: p.hmacValue(ctx, c.RawCode), - ExpiresAt: now.Add(c.ExpiresIn), - IssuedAt: now, - FlowID: c.FlowID, - NID: p.NetworkID(ctx), - } - - if c.VerifiableAddress == nil { - return nil, errors.WithStack(herodot.ErrNotFound.WithReason("can't create a verification code without a verifiable address")) - } - - verificationCode.VerifiableAddress = c.VerifiableAddress - verificationCode.VerifiableAddressID = uuid.NullUUID{ - UUID: c.VerifiableAddress.ID, - Valid: true, - } - - // This should not create the request eagerly because otherwise we might accidentally create an address that isn't - // supposed to be in the database. - if err := p.GetConnection(ctx).Create(verificationCode); err != nil { - return nil, err - } - return verificationCode, nil -} - -func (p *Persister) DeleteVerificationCodesOfFlow(ctx context.Context, fID uuid.UUID) error { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteVerificationCodesOfFlow") - defer span.End() - - //#nosec G201 -- TableName is static - return p.GetConnection(ctx). - RawQuery( - fmt.Sprintf("DELETE FROM %s WHERE selfservice_verification_flow_id = ? AND nid = ?", new(code.VerificationCode).TableName(ctx)), - fID, - p.NetworkID(ctx), - ).Exec() -} diff --git a/persistence/sql/persister_verification_code.go b/persistence/sql/persister_verification_code.go new file mode 100644 index 000000000000..0b469bad07ce --- /dev/null +++ b/persistence/sql/persister_verification_code.go @@ -0,0 +1,77 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package sql + +import ( + "context" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + + "github.com/ory/herodot" + "github.com/ory/kratos/identity" + "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/code" + "github.com/ory/x/sqlcon" +) + +func (p *Persister) CreateVerificationCode(ctx context.Context, params *code.CreateVerificationCodeParams) (*code.VerificationCode, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateVerificationCode") + defer span.End() + + now := time.Now().UTC() + verificationCode := &code.VerificationCode{ + ID: uuid.Nil, + CodeHMAC: p.hmacValue(ctx, params.RawCode), + ExpiresAt: now.Add(params.ExpiresIn), + IssuedAt: now, + FlowID: params.FlowID, + NID: p.NetworkID(ctx), + } + + if params.VerifiableAddress == nil { + return nil, errors.WithStack(herodot.ErrNotFound.WithReason("can't create a verification code without a verifiable address")) + } + + verificationCode.VerifiableAddress = params.VerifiableAddress + verificationCode.VerifiableAddressID = uuid.NullUUID{ + UUID: params.VerifiableAddress.ID, + Valid: true, + } + + // This should not create the request eagerly because otherwise we might accidentally create an address that isn't + // supposed to be in the database. + if err := p.GetConnection(ctx).Create(verificationCode); err != nil { + return nil, err + } + + return verificationCode, nil +} + +func (p *Persister) UseVerificationCode(ctx context.Context, flowID uuid.UUID, userProvidedCode string) (*code.VerificationCode, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseVerificationCode") + defer span.End() + + codeRow, err := useOneTimeCode[code.VerificationCode, *code.VerificationCode](ctx, p, flowID, userProvidedCode, new(verification.Flow).TableName(ctx), "selfservice_verification_flow_id") + if err != nil { + return nil, err + } + + var va identity.VerifiableAddress + if err := p.Connection(ctx).Where("id = ? AND nid = ?", codeRow.VerifiableAddressID, p.NetworkID(ctx)).First(&va); err != nil { + // This should fail on not found errors too, because the verifiable address must exist for the flow to work. + return nil, sqlcon.HandleError(err) + } + codeRow.VerifiableAddress = &va + + return codeRow, nil +} + +func (p *Persister) DeleteVerificationCodesOfFlow(ctx context.Context, fID uuid.UUID) error { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteVerificationCodesOfFlow") + defer span.End() + + return p.GetConnection(ctx).Where("selfservice_verification_flow_id = ? AND nid = ?", fID, p.NetworkID(ctx)).Delete(&code.VerificationCode{}) +} diff --git a/schema/errors.go b/schema/errors.go index 6c2f20a411b5..29a9ef0869a6 100644 --- a/schema/errors.go +++ b/schema/errors.go @@ -360,7 +360,7 @@ func NewNoCodeAuthnCredentials() error { func NewTraitsMismatch() error { return errors.WithStack(&ValidationError{ ValidationError: &jsonschema.ValidationError{ - Message: `the submitted form data has changed from the previous submission. Please try again.`, + Message: `the submitted form data has changed from the previous submission`, InstancePtr: "#/", }, Messages: new(text.Messages).Add(text.NewErrorValidationTraitsMismatch()), @@ -370,7 +370,7 @@ func NewTraitsMismatch() error { func NewRegistrationCodeInvalid() error { return errors.WithStack(&ValidationError{ ValidationError: &jsonschema.ValidationError{ - Message: `the provided code is invalid or has already been used. Please try again.`, + Message: `the provided code is invalid or has already been used`, InstancePtr: "#/", }, Messages: new(text.Messages).Add(text.NewErrorValidationRegistrationCodeInvalidOrAlreadyUsed()), @@ -380,7 +380,7 @@ func NewRegistrationCodeInvalid() error { func NewLoginCodeInvalid() error { return errors.WithStack(&ValidationError{ ValidationError: &jsonschema.ValidationError{ - Message: `the provided code is invalid or has already been used. Please try again.`, + Message: `the provided code is invalid or has already been used`, InstancePtr: "#/", }, Messages: new(text.Messages).Add(text.NewErrorValidationLoginCodeInvalidOrAlreadyUsed()), diff --git a/selfservice/flow/request.go b/selfservice/flow/request.go index ef1f9ea3c617..6c7bc9709525 100644 --- a/selfservice/flow/request.go +++ b/selfservice/flow/request.go @@ -102,10 +102,7 @@ func MethodEnabledAndAllowedFromRequest(r *http.Request, flow FlowName, expected return MethodEnabledAndAllowed(r.Context(), flow, expected, method.Method, d) } -func MethodEnabledAndAllowed(ctx context.Context, flowName FlowName, expected, actual string, d interface { - config.Provider -}, -) error { +func MethodEnabledAndAllowed(ctx context.Context, flowName FlowName, expected, actual string, d config.Provider) error { if actual != expected { return errors.WithStack(ErrStrategyNotResponsible) } @@ -117,8 +114,6 @@ func MethodEnabledAndAllowed(ctx context.Context, flowName FlowName, expected, a ok = d.Config().SelfServiceCodeStrategy(ctx).PasswordlessEnabled case VerificationFlow, RecoveryFlow: ok = d.Config().SelfServiceCodeStrategy(ctx).Enabled - default: - ok = false } } else { ok = d.Config().SelfServiceStrategy(ctx, expected).Enabled diff --git a/selfservice/strategy/code/code_login.go b/selfservice/strategy/code/code_login.go index 7c183413799d..689d52f0cb4f 100644 --- a/selfservice/strategy/code/code_login.go +++ b/selfservice/strategy/code/code_login.go @@ -8,6 +8,10 @@ import ( "database/sql" "time" + "github.com/pkg/errors" + + "github.com/ory/kratos/selfservice/flow" + "github.com/gofrs/uuid" "github.com/ory/kratos/identity" @@ -61,16 +65,25 @@ func (LoginCode) TableName(ctx context.Context) string { return "identity_login_codes" } -func (f LoginCode) IsExpired() bool { - return f.ExpiresAt.Before(time.Now()) +func (f *LoginCode) Validate() error { + if f == nil { + return errors.WithStack(ErrCodeNotFound) + } + if f.ExpiresAt.Before(time.Now().UTC()) { + return errors.WithStack(flow.NewFlowExpiredError(f.ExpiresAt)) + } + if f.UsedAt.Valid { + return errors.WithStack(ErrCodeAlreadyUsed) + } + return nil } -func (r LoginCode) WasUsed() bool { - return r.UsedAt.Valid +func (f *LoginCode) GetHMACCode() string { + return f.CodeHMAC } -func (f LoginCode) IsValid() bool { - return !f.IsExpired() && !f.WasUsed() +func (f *LoginCode) GetID() uuid.UUID { + return f.ID } // swagger:ignore diff --git a/selfservice/strategy/code/code_login_test.go b/selfservice/strategy/code/code_login_test.go new file mode 100644 index 000000000000..87b50155a15b --- /dev/null +++ b/selfservice/strategy/code/code_login_test.go @@ -0,0 +1,81 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package code_test + +import ( + "database/sql" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/internal" + "github.com/ory/kratos/selfservice/flow" + "github.com/ory/kratos/selfservice/flow/login" + "github.com/ory/kratos/selfservice/strategy/code" + "github.com/ory/kratos/x" + "github.com/ory/x/urlx" +) + +func TestLoginCode(t *testing.T) { + conf, _ := internal.NewFastRegistryWithMocks(t) + + newCode := func(expiresIn time.Duration, f *login.Flow) *code.LoginCode { + return &code.LoginCode{ + ID: x.NewUUID(), + FlowID: f.ID, + ExpiresAt: time.Now().Add(expiresIn), + } + } + + req := &http.Request{URL: urlx.ParseOrPanic("https://www.ory.sh/")} + t.Run("method=Validate", func(t *testing.T) { + t.Parallel() + + t.Run("case=returns error if flow is expired", func(t *testing.T) { + f, err := login.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(-time.Hour, f) + expected := new(flow.ExpiredError) + require.ErrorAs(t, c.Validate(), &expected) + }) + t.Run("case=returns no error if flow is not expired", func(t *testing.T) { + f, err := login.NewFlow(conf, time.Hour, "", req, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(time.Hour, f) + require.NoError(t, c.Validate()) + }) + + t.Run("case=returns error if flow has been used", func(t *testing.T) { + f, err := login.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(time.Hour, f) + c.UsedAt = sql.NullTime{ + Time: time.Now(), + Valid: true, + } + require.ErrorIs(t, c.Validate(), code.ErrCodeAlreadyUsed) + }) + + t.Run("case=returns no error if flow has not been used", func(t *testing.T) { + f, err := login.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(time.Hour, f) + c.UsedAt = sql.NullTime{ + Valid: false, + } + require.NoError(t, c.Validate()) + }) + + t.Run("case=returns error if flow is nil", func(t *testing.T) { + var c *code.LoginCode + require.ErrorIs(t, c.Validate(), code.ErrCodeNotFound) + }) + }) +} diff --git a/selfservice/strategy/code/code_recovery.go b/selfservice/strategy/code/code_recovery.go index f87a9c398de7..8d0cbe926063 100644 --- a/selfservice/strategy/code/code_recovery.go +++ b/selfservice/strategy/code/code_recovery.go @@ -8,6 +8,10 @@ import ( "database/sql" "time" + "github.com/pkg/errors" + + "github.com/ory/kratos/selfservice/flow" + "github.com/gofrs/uuid" "github.com/ory/herodot" @@ -73,16 +77,25 @@ func (RecoveryCode) TableName(ctx context.Context) string { return "identity_recovery_codes" } -func (f RecoveryCode) IsExpired() bool { - return f.ExpiresAt.Before(time.Now()) +func (f *RecoveryCode) Validate() error { + if f == nil { + return errors.WithStack(ErrCodeNotFound) + } + if f.ExpiresAt.Before(time.Now().UTC()) { + return errors.WithStack(flow.NewFlowExpiredError(f.ExpiresAt)) + } + if f.UsedAt.Valid { + return errors.WithStack(ErrCodeAlreadyUsed) + } + return nil } -func (r RecoveryCode) WasUsed() bool { - return r.UsedAt.Valid +func (f *RecoveryCode) GetHMACCode() string { + return f.CodeHMAC } -func (f RecoveryCode) IsValid() bool { - return !f.IsExpired() && !f.WasUsed() +func (f *RecoveryCode) GetID() uuid.UUID { + return f.ID } type CreateRecoveryCodeParams struct { diff --git a/selfservice/strategy/code/code_recovery_test.go b/selfservice/strategy/code/code_recovery_test.go index 3aadf350bb8c..dc099f02cae0 100644 --- a/selfservice/strategy/code/code_recovery_test.go +++ b/selfservice/strategy/code/code_recovery_test.go @@ -34,26 +34,26 @@ func TestRecoveryCode(t *testing.T) { } req := &http.Request{URL: urlx.ParseOrPanic("https://www.ory.sh/")} + t.Run("method=Validate", func(t *testing.T) { + t.Parallel() - t.Run("method=IsExpired", func(t *testing.T) { - t.Run("case=returns true if flow is expired", func(t *testing.T) { + t.Run("case=returns error if flow is expired", func(t *testing.T) { f, err := recovery.NewFlow(conf, -time.Hour, "", req, nil, flow.TypeBrowser) require.NoError(t, err) c := newCode(-time.Hour, f) - require.True(t, c.IsExpired()) + expected := new(flow.ExpiredError) + require.ErrorAs(t, c.Validate(), &expected) }) - t.Run("case=returns false if flow is not expired", func(t *testing.T) { + t.Run("case=returns no error if flow is not expired", func(t *testing.T) { f, err := recovery.NewFlow(conf, time.Hour, "", req, nil, flow.TypeBrowser) require.NoError(t, err) c := newCode(time.Hour, f) - require.False(t, c.IsExpired()) + require.NoError(t, c.Validate()) }) - }) - t.Run("method=WasUsed", func(t *testing.T) { - t.Run("case=returns true if flow has been used", func(t *testing.T) { + t.Run("case=returns error if flow has been used", func(t *testing.T) { f, err := recovery.NewFlow(conf, -time.Hour, "", req, nil, flow.TypeBrowser) require.NoError(t, err) @@ -62,9 +62,10 @@ func TestRecoveryCode(t *testing.T) { Time: time.Now(), Valid: true, } - require.True(t, c.WasUsed()) + require.ErrorIs(t, c.Validate(), code.ErrCodeAlreadyUsed) }) - t.Run("case=returns false if flow has not been used", func(t *testing.T) { + + t.Run("case=returns no error if flow has not been used", func(t *testing.T) { f, err := recovery.NewFlow(conf, -time.Hour, "", req, nil, flow.TypeBrowser) require.NoError(t, err) @@ -72,7 +73,12 @@ func TestRecoveryCode(t *testing.T) { c.UsedAt = sql.NullTime{ Valid: false, } - require.False(t, c.WasUsed()) + require.NoError(t, c.Validate()) + }) + + t.Run("case=returns error if flow is nil", func(t *testing.T) { + var c *code.RecoveryCode + require.ErrorIs(t, c.Validate(), code.ErrCodeNotFound) }) }) } diff --git a/selfservice/strategy/code/code_registration.go b/selfservice/strategy/code/code_registration.go index 256914760782..d8691b54b224 100644 --- a/selfservice/strategy/code/code_registration.go +++ b/selfservice/strategy/code/code_registration.go @@ -8,6 +8,10 @@ import ( "database/sql" "time" + "github.com/pkg/errors" + + "github.com/ory/kratos/selfservice/flow" + "github.com/gofrs/uuid" "github.com/ory/kratos/identity" @@ -60,16 +64,25 @@ func (RegistrationCode) TableName(ctx context.Context) string { return "identity_registration_codes" } -func (f RegistrationCode) IsExpired() bool { - return f.ExpiresAt.Before(time.Now()) +func (f *RegistrationCode) Validate() error { + if f == nil { + return errors.WithStack(ErrCodeNotFound) + } + if f.ExpiresAt.Before(time.Now().UTC()) { + return errors.WithStack(flow.NewFlowExpiredError(f.ExpiresAt)) + } + if f.UsedAt.Valid { + return errors.WithStack(ErrCodeAlreadyUsed) + } + return nil } -func (r RegistrationCode) WasUsed() bool { - return r.UsedAt.Valid +func (f *RegistrationCode) GetHMACCode() string { + return f.CodeHMAC } -func (f RegistrationCode) IsValid() bool { - return !f.IsExpired() && !f.WasUsed() +func (f *RegistrationCode) GetID() uuid.UUID { + return f.ID } // swagger:ignore diff --git a/selfservice/strategy/code/code_registration_test.go b/selfservice/strategy/code/code_registration_test.go new file mode 100644 index 000000000000..034d9fcf2b92 --- /dev/null +++ b/selfservice/strategy/code/code_registration_test.go @@ -0,0 +1,80 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package code_test + +import ( + "database/sql" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/internal" + "github.com/ory/kratos/selfservice/flow" + "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/selfservice/strategy/code" + "github.com/ory/kratos/x" + "github.com/ory/x/urlx" +) + +func TestRegistrationCode(t *testing.T) { + conf, _ := internal.NewFastRegistryWithMocks(t) + newCode := func(expiresIn time.Duration, f *registration.Flow) *code.RegistrationCode { + return &code.RegistrationCode{ + ID: x.NewUUID(), + FlowID: f.ID, + ExpiresAt: time.Now().Add(expiresIn), + } + } + + req := &http.Request{URL: urlx.ParseOrPanic("https://www.ory.sh/")} + t.Run("method=Validate", func(t *testing.T) { + t.Parallel() + + t.Run("case=returns error if flow is expired", func(t *testing.T) { + f, err := registration.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(-time.Hour, f) + expected := new(flow.ExpiredError) + require.ErrorAs(t, c.Validate(), &expected) + }) + t.Run("case=returns no error if flow is not expired", func(t *testing.T) { + f, err := registration.NewFlow(conf, time.Hour, "", req, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(time.Hour, f) + require.NoError(t, c.Validate()) + }) + + t.Run("case=returns error if flow has been used", func(t *testing.T) { + f, err := registration.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(time.Hour, f) + c.UsedAt = sql.NullTime{ + Time: time.Now(), + Valid: true, + } + require.ErrorIs(t, c.Validate(), code.ErrCodeAlreadyUsed) + }) + + t.Run("case=returns no error if flow has not been used", func(t *testing.T) { + f, err := registration.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(time.Hour, f) + c.UsedAt = sql.NullTime{ + Valid: false, + } + require.NoError(t, c.Validate()) + }) + + t.Run("case=returns error if flow is nil", func(t *testing.T) { + var c *code.RegistrationCode + require.ErrorIs(t, c.Validate(), code.ErrCodeNotFound) + }) + }) +} diff --git a/selfservice/strategy/code/code_sender.go b/selfservice/strategy/code/code_sender.go index db3af496b59e..d3667aea3b07 100644 --- a/selfservice/strategy/code/code_sender.go +++ b/selfservice/strategy/code/code_sender.go @@ -72,6 +72,10 @@ func (s *Sender) SendCode(ctx context.Context, f flow.Flow, id *identity.Identit // send to all addresses for _, address := range addresses { + // We have to generate a unique code per address, or otherwise it is not possible to link which + // address was used to verify the code. + // + // See also [this discussion](https://github.com/ory/kratos/pull/3456#discussion_r1307560988). rawCode := GenerateCode() switch f.GetFlowName() { diff --git a/selfservice/strategy/code/verification_code.go b/selfservice/strategy/code/code_verification.go similarity index 93% rename from selfservice/strategy/code/verification_code.go rename to selfservice/strategy/code/code_verification.go index a4f204bae176..324766ebfa1c 100644 --- a/selfservice/strategy/code/verification_code.go +++ b/selfservice/strategy/code/code_verification.go @@ -62,6 +62,9 @@ func (VerificationCode) TableName(context.Context) string { // - If the code was already used `ErrCodeAlreadyUsed` is returnd // - Otherwise, `nil` is returned func (f *VerificationCode) Validate() error { + if f == nil { + return errors.WithStack(ErrCodeNotFound) + } if f.ExpiresAt.Before(time.Now().UTC()) { return errors.WithStack(flow.NewFlowExpiredError(f.ExpiresAt)) } @@ -71,6 +74,14 @@ func (f *VerificationCode) Validate() error { return nil } +func (f *VerificationCode) GetHMACCode() string { + return f.CodeHMAC +} + +func (f *VerificationCode) GetID() uuid.UUID { + return f.ID +} + type CreateVerificationCodeParams struct { // Code represents the recovery code RawCode string diff --git a/selfservice/strategy/code/code_verification_test.go b/selfservice/strategy/code/code_verification_test.go new file mode 100644 index 000000000000..3217b7dcbb00 --- /dev/null +++ b/selfservice/strategy/code/code_verification_test.go @@ -0,0 +1,81 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package code_test + +import ( + "database/sql" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/internal" + "github.com/ory/kratos/selfservice/flow" + "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/code" + "github.com/ory/kratos/x" + "github.com/ory/x/urlx" +) + +func TestVerificationCode(t *testing.T) { + conf, _ := internal.NewFastRegistryWithMocks(t) + + newCode := func(expiresIn time.Duration, f *verification.Flow) *code.VerificationCode { + return &code.VerificationCode{ + ID: x.NewUUID(), + FlowID: f.ID, + ExpiresAt: time.Now().Add(expiresIn), + } + } + + req := &http.Request{URL: urlx.ParseOrPanic("https://www.ory.sh/")} + t.Run("method=Validate", func(t *testing.T) { + t.Parallel() + + t.Run("case=returns error if flow is expired", func(t *testing.T) { + f, err := verification.NewFlow(conf, -time.Hour, "", req, nil, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(-time.Hour, f) + expected := new(flow.ExpiredError) + require.ErrorAs(t, c.Validate(), &expected) + }) + t.Run("case=returns no error if flow is not expired", func(t *testing.T) { + f, err := verification.NewFlow(conf, time.Hour, "", req, nil, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(time.Hour, f) + require.NoError(t, c.Validate()) + }) + + t.Run("case=returns error if flow has been used", func(t *testing.T) { + f, err := verification.NewFlow(conf, -time.Hour, "", req, nil, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(time.Hour, f) + c.UsedAt = sql.NullTime{ + Time: time.Now(), + Valid: true, + } + require.ErrorIs(t, c.Validate(), code.ErrCodeAlreadyUsed) + }) + + t.Run("case=returns no error if flow has not been used", func(t *testing.T) { + f, err := verification.NewFlow(conf, -time.Hour, "", req, nil, flow.TypeBrowser) + require.NoError(t, err) + + c := newCode(time.Hour, f) + c.UsedAt = sql.NullTime{ + Valid: false, + } + require.NoError(t, c.Validate()) + }) + + t.Run("case=returns error if flow is nil", func(t *testing.T) { + var c *code.VerificationCode + require.ErrorIs(t, c.Validate(), code.ErrCodeNotFound) + }) + }) +} diff --git a/selfservice/strategy/code/strategy.go b/selfservice/strategy/code/strategy.go index 032ed792a9e7..94c8de75e9b6 100644 --- a/selfservice/strategy/code/strategy.go +++ b/selfservice/strategy/code/strategy.go @@ -4,8 +4,6 @@ package code import ( - "context" - "encoding/json" "net/http" "strings" @@ -63,6 +61,7 @@ type ( x.CSRFTokenGeneratorProvider x.WriterProvider x.LoggingProvider + x.TracingProvider config.Provider @@ -113,6 +112,10 @@ type ( deps strategyDependencies dx *decoderx.HTTP } + + codeIdentifier struct { + Identifier string `json:"identifier"` + } ) func NewStrategy(deps strategyDependencies) *Strategy { @@ -294,169 +297,37 @@ func (s *Strategy) PopulateMethod(r *http.Request, f flow.Flow) error { // NewCodeUINodes creates a fresh UI for the code flow. // this is used with the `recovery`, `verification`, `registration` and `login` flows. -func (s *Strategy) NewCodeUINodes(r *http.Request, f flow.Flow, data json.RawMessage) error { +func (s *Strategy) NewCodeUINodes(r *http.Request, f flow.Flow, data any) error { if err := s.PopulateMethod(r, f); err != nil { return err } - // on Registration flow we need to populate the form with the values from the initial form generation + prefix := "" // The login flow does not process traits if f.GetFlowName() == flow.RegistrationFlow { - for _, n := range container.NewFromJSON("", node.CodeGroup, data, "traits").Nodes { - // we only set the value and not the whole field because we want to keep types from the initial form generation - f.GetUI().GetNodes().SetValueAttribute(n.ID(), n.Attributes.GetValue()) - } - } else if f.GetFlowName() == flow.LoginFlow { - // on Login flow we need to populate the form with the values from the initial form generation - for _, n := range container.NewFromJSON("", node.DefaultGroup, data, "").Nodes { - f.GetUI().GetNodes().SetValueAttribute(n.ID(), n.Attributes.GetValue()) - } + // The registration form does however + prefix = "traits" } - return nil -} - -type ( - FlowType interface { - *login.Flow | *registration.Flow | *recovery.Flow | *verification.Flow - flow.Flow - } - FlowPayload interface { - *updateLoginFlowWithCodeMethod | *updateRegistrationFlowWithCodeMethod | *updateRecoveryFlowWithCodeMethod | *updateVerificationFlowWithCodeMethod - } - CreateCodeState[T FlowType, P FlowPayload] func(context.Context, T, *Strategy, P) error - ValidateCodeState[T FlowType, P FlowPayload] func(context.Context, T, *Strategy, P) error - AlreadyValidatedCodeState[T FlowType, P FlowPayload] func(context.Context, T, *Strategy, P) error - CodeStateManager[T FlowType, P FlowPayload] struct { - f T - payload P - strategy *Strategy - createCodeState CreateCodeState[T, P] - verifyCodeState ValidateCodeState[T, P] - alreadyValidatedCodeState AlreadyValidatedCodeState[T, P] + cont, err := container.NewFromStruct("", node.CodeGroup, data, prefix) + if err != nil { + return err } -) -func NewCodeStateManager[T FlowType, P FlowPayload](f T, s *Strategy, payload P) *CodeStateManager[T, P] { - return &CodeStateManager[T, P]{ - f: f, - strategy: s, - payload: payload, + for _, n := range cont.Nodes { + // we only set the value and not the whole field because we want to keep types from the initial form generation + f.GetUI().GetNodes().SetValueAttribute(n.ID(), n.Attributes.GetValue()) } -} - -func (c *CodeStateManager[T, P]) SetCreateCodeHandler(fn CreateCodeState[T, P]) { - c.createCodeState = fn -} -func (c *CodeStateManager[T, P]) SetCodeVerifyHandler(fn ValidateCodeState[T, P]) { - c.verifyCodeState = fn -} - -func (c *CodeStateManager[T, P]) SetCodeDoneHandler(fn AlreadyValidatedCodeState[T, P]) { - c.alreadyValidatedCodeState = fn -} - -func (c *CodeStateManager[T, P]) validatePayload(ctx context.Context) error { - switch v := any(c.payload).(type) { - case *updateLoginFlowWithCodeMethod: - if len(v.Identifier) == 0 { - return errors.WithStack(schema.NewRequiredError("#/identifier", "identifier")) - } - case *updateRegistrationFlowWithCodeMethod: - if len(v.Traits) == 0 { - return errors.WithStack(schema.NewRequiredError("#/traits", "traits")) - } - case *updateRecoveryFlowWithCodeMethod: - if len(v.Email) == 0 { - return errors.WithStack(schema.NewRequiredError("#/email", "email")) - } - case *updateVerificationFlowWithCodeMethod: - if len(v.Email) == 0 { - return errors.WithStack(schema.NewRequiredError("#/email", "email")) - } - default: - return errors.WithStack(herodot.ErrBadRequest.WithReason("received unexpected flow payload type")) - } return nil } -func (c *CodeStateManager[T, P]) getResend() string { - switch v := any(c.payload).(type) { - case *updateLoginFlowWithCodeMethod: - return v.Resend - case *updateRegistrationFlowWithCodeMethod: - return v.Resend - } - return "" -} - -func (c *CodeStateManager[T, P]) getCode() string { - switch v := any(c.payload).(type) { - case *updateLoginFlowWithCodeMethod: - return v.Code - case *updateRegistrationFlowWithCodeMethod: - return v.Code - case *updateRecoveryFlowWithCodeMethod: - return v.Code - case *updateVerificationFlowWithCodeMethod: - return v.Code - } - return "" -} - -func (c *CodeStateManager[T, P]) Run(ctx context.Context) error { +func SetDefaultFlowState(f flow.Flow, resend string) { // By Default the flow should be in the 'choose method' state. - if c.f.GetState() == "" { - c.f.SetState(flow.StateChooseMethod) - } - - if strings.EqualFold(c.getResend(), "code") { - c.f.SetState(flow.StateChooseMethod) - } - - switch c.f.GetState() { - case flow.StateChooseMethod: - // we are in the first submission state of the flow - - if err := c.validatePayload(ctx); err != nil { - return err - } - - if err := c.createCodeState(ctx, c.f, c.strategy, c.payload); err != nil { - return err - } - - case flow.StateEmailSent: - // we are in the second submission state of the flow - // we need to check the code and update the identity - if len(c.getCode()) == 0 { - return errors.WithStack(schema.NewRequiredError("#/code", "code")) - } - - if err := c.validatePayload(ctx); err != nil { - return err - } - - if err := c.verifyCodeState(ctx, c.f, c.strategy, c.payload); err != nil { - return err - } - case flow.StatePassedChallenge: - return c.alreadyValidatedCodeState(ctx, c.f, c.strategy, c.payload) - default: - return errors.WithStack(errors.New("Unknown flow state")) + if f.GetState() == "" { + f.SetState(flow.StateChooseMethod) } - return nil -} -func (s *Strategy) NextFlowState(f flow.Flow) { - switch f.GetState() { - case flow.StateChooseMethod: - f.SetState(flow.StateEmailSent) - case flow.StateEmailSent: - f.SetState(flow.StatePassedChallenge) - case flow.StatePassedChallenge: - f.SetState(flow.StatePassedChallenge) - default: + if strings.EqualFold(resend, "code") { f.SetState(flow.StateChooseMethod) } } diff --git a/selfservice/strategy/code/strategy_login.go b/selfservice/strategy/code/strategy_login.go index 9345c834590a..264df98e2dd9 100644 --- a/selfservice/strategy/code/strategy_login.go +++ b/selfservice/strategy/code/strategy_login.go @@ -5,7 +5,6 @@ package code import ( "context" - "encoding/json" "net/http" "strings" @@ -13,6 +12,8 @@ import ( "github.com/pkg/errors" "github.com/ory/herodot" + "github.com/ory/x/otelx" + "github.com/ory/kratos/identity" "github.com/ory/kratos/schema" "github.com/ory/kratos/selfservice/flow" @@ -64,15 +65,19 @@ func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.Au } } -func (s *Strategy) HandleLoginError(w http.ResponseWriter, r *http.Request, flow *login.Flow, body *updateLoginFlowWithCodeMethod, err error) error { - if flow != nil { +func (s *Strategy) HandleLoginError(r *http.Request, f *login.Flow, body *updateLoginFlowWithCodeMethod, err error) error { + if errors.Is(err, flow.ErrCompletedByStrategy) { + return err + } + + if f != nil { email := "" if body != nil { email = body.Identifier } - flow.UI.SetCSRF(s.deps.GenerateCSRFToken(r)) - flow.UI.GetNodes().Upsert( + f.UI.SetCSRF(s.deps.GenerateCSRFToken(r)) + f.UI.GetNodes().Upsert( node.NewInputField("identifier", email, node.DefaultGroup, node.InputAttributeTypeText, node.WithRequiredInputAttribute). WithMetaLabel(text.NewInfoNodeLabelID()), ) @@ -85,34 +90,25 @@ func (s *Strategy) PopulateLoginMethod(r *http.Request, requestedAAL identity.Au return s.PopulateMethod(r, lf) } -func (s *Strategy) getIdentity(ctx context.Context, identifier string) (*identity.Identity, *identity.Credentials, error) { - i, _, err := s.deps.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, s.ID(), identifier) +func (s *Strategy) getIdentity(ctx context.Context, identifier string) (_ *identity.Identity, _ *identity.Credentials, err error) { + ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.getIdentity") + defer otelx.End(span, &err) + + i, cred, err := s.deps.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, s.ID(), identifier) if err != nil { return nil, nil, errors.WithStack(schema.NewNoCodeAuthnCredentials()) } - if err := s.deps.IdentityValidator().Validate(ctx, i); err != nil { - return nil, nil, errors.WithStack(schema.NewRequiredError("#/identifier", "identifier")) - } - - cred, ok := i.GetCredentials(s.ID()) - if !ok { - return nil, nil, errors.WithStack(schema.NewNoCodeAuthnCredentials()) - } else if len(cred.Identifiers) == 0 { + if len(cred.Identifiers) == 0 { return nil, nil, errors.WithStack(schema.NewNoCodeAuthnCredentials()) - } else if cred.IdentifierAddressType == "" { - return nil, nil, errors.WithStack(schema.NewRequiredError("#/code", "via")) } return i, cred, nil } -func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, identityID uuid.UUID) (i *identity.Identity, err error) { - s.deps.Audit(). - WithRequest(r). - WithField("identity_id", identityID). - WithField("login_flow_id", f.ID). - Info("Login with the code strategy started.") +func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ uuid.UUID) (_ *identity.Identity, err error) { + ctx, span := s.deps.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.code.strategy.Login") + defer otelx.End(span, &err) if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.deps); err != nil { return nil, err @@ -129,157 +125,154 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, decoderx.MustHTTPRawJSONSchemaCompiler(loginMethodSchema), decoderx.HTTPDecoderAllowedMethods("POST"), decoderx.HTTPDecoderJSONFollowsFormFormat()); err != nil { - return nil, s.HandleLoginError(w, r, f, &p, err) + return nil, s.HandleLoginError(r, f, &p, err) } - if err := flow.EnsureCSRF(s.deps, r, f.Type, s.deps.Config().DisableAPIFlowEnforcement(r.Context()), s.deps.GenerateCSRFToken, p.CSRFToken); err != nil { - return nil, s.HandleLoginError(w, r, f, &p, err) + if err := flow.EnsureCSRF(s.deps, r, f.Type, s.deps.Config().DisableAPIFlowEnforcement(ctx), s.deps.GenerateCSRFToken, p.CSRFToken); err != nil { + return nil, s.HandleLoginError(r, f, &p, err) } - codeManager := NewCodeStateManager(f, s, &p) + // By Default the flow should be in the 'choose method' state. + SetDefaultFlowState(f, p.Resend) - codeManager.SetCreateCodeHandler(func(ctx context.Context, f *login.Flow, strategy *Strategy, p *updateLoginFlowWithCodeMethod) error { - strategy.deps.Audit(). - WithSensitiveField("identifier", p.Identifier). - Info("Creating login code state.") - - // Step 1: Get the identity - i, cred, err := strategy.getIdentity(ctx, p.Identifier) + switch f.GetState() { + case flow.StateChooseMethod: + if err := s.loginSendEmail(ctx, w, r, f, &p); err != nil { + return nil, s.HandleLoginError(r, f, &p, err) + } + return nil, nil + case flow.StateEmailSent: + i, err := s.loginVerifyCode(ctx, r, f, &p) if err != nil { - return err + return nil, s.HandleLoginError(r, f, &p, err) } + return i, nil + case flow.StatePassedChallenge: + return nil, s.HandleLoginError(r, f, &p, errors.WithStack(schema.NewNoLoginStrategyResponsible())) + } - // Step 2: Delete any previous login codes for this flow ID - if err := strategy.deps.LoginCodePersister().DeleteLoginCodesOfFlow(ctx, f.GetID()); err != nil { - return errors.WithStack(err) - } + return nil, s.HandleLoginError(r, f, &p, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unexpected flow state: %s", f.GetState()))) +} - var identifier string - for _, id := range cred.Identifiers { - if strings.EqualFold(p.Identifier, id) { - identifier = id - } - } +func (s *Strategy) loginSendEmail(ctx context.Context, w http.ResponseWriter, r *http.Request, f *login.Flow, p *updateLoginFlowWithCodeMethod) (err error) { + ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.loginSendEmail") + defer otelx.End(span, &err) - addresses := []Address{ - { - To: identifier, - Via: identity.CodeAddressType(cred.IdentifierAddressType), - }, - } + if len(p.Identifier) == 0 { + return errors.WithStack(schema.NewRequiredError("#/identifier", "identifier")) + } - // kratos only supports `email` identifiers at the moment with the code method - // this is validated in the identity validation step above - if err := strategy.deps.CodeSender().SendCode(ctx, f, i, addresses...); err != nil { - return errors.WithStack(err) - } + p.Identifier = maybeNormalizeEmail(p.Identifier) - // sets the flow state to code sent - s.NextFlowState(f) + // Step 1: Get the identity + i, _, err := s.getIdentity(ctx, p.Identifier) + if err != nil { + return err + } - nodeData, err := json.Marshal(struct { - Identifier string `json:"identifier"` - }{ - Identifier: p.Identifier, - }) - if err != nil { - return errors.WithStack(err) - } + // Step 2: Delete any previous login codes for this flow ID + if err := s.deps.LoginCodePersister().DeleteLoginCodesOfFlow(ctx, f.GetID()); err != nil { + return errors.WithStack(err) + } - if err := s.NewCodeUINodes(r, f, nodeData); err != nil { - return err - } + addresses := []Address{{ + To: p.Identifier, + Via: identity.CodeAddressType(identity.AddressTypeEmail), + }} - f.Active = identity.CredentialsTypeCodeAuth - if err = strategy.deps.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { - return err - } + // kratos only supports `email` identifiers at the moment with the code method + // this is validated in the identity validation step above + if err := s.deps.CodeSender().SendCode(ctx, f, i, addresses...); err != nil { + return errors.WithStack(err) + } - if x.IsJSONRequest(r) { - strategy.deps.Writer().WriteCode(w, r, http.StatusBadRequest, f) - } else { - http.Redirect(w, r, f.AppendTo(strategy.deps.Config().SelfServiceFlowLoginUI(ctx)).String(), http.StatusSeeOther) - } + // sets the flow state to code sent + f.SetState(flow.NextState(f.GetState())) + + if err := s.NewCodeUINodes(r, f, &codeIdentifier{Identifier: p.Identifier}); err != nil { + return err + } - // we return an error to the flow handler so that it does not continue execution of the hooks. - // we are not done with the login flow yet. The user needs to verify the code and then we need to persist the identity. - return errors.WithStack(flow.ErrCompletedByStrategy) - }) + f.Active = identity.CredentialsTypeCodeAuth + if err = s.deps.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { + return err + } - codeManager.SetCodeVerifyHandler(func(ctx context.Context, f *login.Flow, strategy *Strategy, p *updateLoginFlowWithCodeMethod) error { - strategy.deps.Audit(). - WithSensitiveField("code", p.Code). - WithSensitiveField("identifier", p.Identifier). - Debug("Verifying login code") + if x.IsJSONRequest(r) { + s.deps.Writer().WriteCode(w, r, http.StatusBadRequest, f) + } else { + http.Redirect(w, r, f.AppendTo(s.deps.Config().SelfServiceFlowLoginUI(ctx)).String(), http.StatusSeeOther) + } - // Step 1: Get the identity - i, _, err = strategy.getIdentity(ctx, p.Identifier) - if err != nil { - return err - } + // we return an error to the flow handler so that it does not continue execution of the hooks. + // we are not done with the login flow yet. The user needs to verify the code and then we need to persist the identity. + return errors.WithStack(flow.ErrCompletedByStrategy) +} - loginCode, err := strategy.deps.LoginCodePersister().UseLoginCode(ctx, f.ID, i.ID, p.Code) - if err != nil { - if errors.Is(err, ErrCodeNotFound) { - return schema.NewLoginCodeInvalid() - } - return errors.WithStack(err) - } +// If identifier is an email, we lower case it because on mobile phones the first letter sometimes is capitalized. +func maybeNormalizeEmail(input string) string { + if strings.Contains(input, "@") { + return strings.ToLower(input) + } + return input +} - i, err = strategy.deps.PrivilegedIdentityPool().GetIdentity(ctx, loginCode.IdentityID, identity.ExpandDefault) - if err != nil { - return errors.WithStack(err) - } +func (s *Strategy) loginVerifyCode(ctx context.Context, r *http.Request, f *login.Flow, p *updateLoginFlowWithCodeMethod) (_ *identity.Identity, err error) { + ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.loginVerifyCode") + defer otelx.End(span, &err) + + // we are in the second submission state of the flow + // we need to check the code and update the identity + if p.Code == "" { + return nil, errors.WithStack(schema.NewRequiredError("#/code", "code")) + } - // Step 2: The code was correct - f.Active = identity.CredentialsTypeCodeAuth + if len(p.Identifier) == 0 { + return nil, errors.WithStack(schema.NewRequiredError("#/identifier", "identifier")) + } - // since nothing has errored yet, we can assume that the code is correct - // and we can update the login flow - strategy.NextFlowState(f) + p.Identifier = maybeNormalizeEmail(p.Identifier) - if err := strategy.deps.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { - return errors.WithStack(err) - } + // Step 1: Get the identity + i, _, err := s.getIdentity(ctx, p.Identifier) + if err != nil { + return nil, err + } - for idx := range i.VerifiableAddresses { - va := i.VerifiableAddresses[idx] - if !va.Verified && loginCode.Address == va.Value { - va.Verified = true - va.Status = identity.VerifiableAddressStatusCompleted - if err := strategy.deps.PrivilegedIdentityPool().UpdateVerifiableAddress(r.Context(), &va); err != nil { - return err - } - break - } + loginCode, err := s.deps.LoginCodePersister().UseLoginCode(ctx, f.ID, i.ID, p.Code) + if err != nil { + if errors.Is(err, ErrCodeNotFound) { + return nil, schema.NewLoginCodeInvalid() } + return nil, errors.WithStack(err) + } - return nil - }) + i, err = s.deps.PrivilegedIdentityPool().GetIdentity(ctx, loginCode.IdentityID, identity.ExpandDefault) + if err != nil { + return nil, errors.WithStack(err) + } - codeManager.SetCodeDoneHandler(func(ctx context.Context, f *login.Flow, strategy *Strategy, p *updateLoginFlowWithCodeMethod) error { - strategy.deps.Audit(). - WithSensitiveField("identifier", p.Identifier). - Debug("The login flow has already been completed, but is being re-requested.") - return errors.WithStack(schema.NewNoLoginStrategyResponsible()) - }) + // Step 2: The code was correct + f.Active = identity.CredentialsTypeCodeAuth - if err := codeManager.Run(r.Context()); err != nil { - if errors.Is(err, flow.ErrCompletedByStrategy) { - return nil, err - } - // the error is already handled by the registered code states - return i, s.HandleLoginError(w, r, f, &p, err) + // since nothing has errored yet, we can assume that the code is correct + // and we can update the login flow + f.SetState(flow.NextState(f.GetState())) + + if err := s.deps.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { + return nil, errors.WithStack(err) } - // a precaution in case the code manager did not set the identity - if i == nil { - s.deps.Audit(). - WithSensitiveField("identifier", p.Identifier). - WithRequest(r). - WithField("login_flow", f). - Error("The code manager did not set the identity.") - return nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("the login flow did not complete successfully")) + for idx := range i.VerifiableAddresses { + va := i.VerifiableAddresses[idx] + if !va.Verified && loginCode.Address == va.Value { + va.Verified = true + va.Status = identity.VerifiableAddressStatusCompleted + if err := s.deps.PrivilegedIdentityPool().UpdateVerifiableAddress(ctx, &va); err != nil { + return nil, err + } + break + } } return i, nil diff --git a/selfservice/strategy/code/strategy_login_test.go b/selfservice/strategy/code/strategy_login_test.go index 3f4a631a9e62..a371fdfbf5f6 100644 --- a/selfservice/strategy/code/strategy_login_test.go +++ b/selfservice/strategy/code/strategy_login_test.go @@ -12,6 +12,8 @@ import ( "net/url" "testing" + "github.com/ory/x/stringsx" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" @@ -162,6 +164,27 @@ func TestLoginCodeStrategy(t *testing.T) { }, } { t.Run("test="+tc.d, func(t *testing.T) { + t.Run("case=email identifier should be case insensitive", func(t *testing.T) { + // create login flow + s := createLoginFlow(ctx, t, public, tc.isSPA) + + // submit email + s = submitLogin(ctx, t, s, tc.isSPA, func(v *url.Values) { + v.Set("identifier", stringsx.ToUpperInitial(s.identityEmail)) + }, false, nil) + + message := testhelpers.CourierExpectMessage(ctx, t, reg, s.identityEmail, "Login to your account") + assert.Contains(t, message.Body, "please login to your account by entering the following code") + + loginCode := testhelpers.CourierExpectCodeInMessage(t, message, 1) + assert.NotEmpty(t, loginCode) + + // 3. Submit OTP + submitLogin(ctx, t, s, tc.isSPA, func(v *url.Values) { + v.Set("code", loginCode) + }, true, nil) + }) + t.Run("case=should be able to log in with code", func(t *testing.T) { // create login flow s := createLoginFlow(ctx, t, public, tc.isSPA) @@ -331,7 +354,7 @@ func TestLoginCodeStrategy(t *testing.T) { }) }) - t.Run("case=resend code shoud invalidate previous code", func(t *testing.T) { + t.Run("case=resend code should invalidate previous code", func(t *testing.T) { ctx := context.Background() s := createLoginFlow(ctx, t, public, tc.isSPA) diff --git a/selfservice/strategy/code/strategy_registration.go b/selfservice/strategy/code/strategy_registration.go index dc4dce7a6b8d..f4c97641f055 100644 --- a/selfservice/strategy/code/strategy_registration.go +++ b/selfservice/strategy/code/strategy_registration.go @@ -10,6 +10,9 @@ import ( "net/http" "strings" + "github.com/ory/herodot" + "github.com/ory/x/otelx" + "github.com/pkg/errors" "github.com/ory/kratos/identity" @@ -59,19 +62,27 @@ type updateRegistrationFlowWithCodeMethod struct { Resend string `json:"resend" form:"resend"` } +func (p *updateRegistrationFlowWithCodeMethod) GetResend() string { + return p.Resend +} + func (s *Strategy) RegisterRegistrationRoutes(*x.RouterPublic) {} -func (s *Strategy) HandleRegistrationError(w http.ResponseWriter, r *http.Request, flow *registration.Flow, body *updateRegistrationFlowWithCodeMethod, err error) error { - if flow != nil { +func (s *Strategy) HandleRegistrationError(ctx context.Context, r *http.Request, f *registration.Flow, body *updateRegistrationFlowWithCodeMethod, err error) error { + if errors.Is(err, flow.ErrCompletedByStrategy) { + return err + } + + if f != nil { if body != nil { - action := flow.AppendTo(urlx.AppendPaths(s.deps.Config().SelfPublicURL(r.Context()), registration.RouteSubmitFlow)).String() + action := f.AppendTo(urlx.AppendPaths(s.deps.Config().SelfPublicURL(ctx), registration.RouteSubmitFlow)).String() for _, n := range container.NewFromJSON(action, node.CodeGroup, body.Traits, "traits").Nodes { // we only set the value and not the whole field because we want to keep types from the initial form generation - flow.UI.Nodes.SetValueAttribute(n.ID(), n.Attributes.GetValue()) + f.UI.Nodes.SetValueAttribute(n.ID(), n.Attributes.GetValue()) } } - flow.UI.SetCSRF(s.deps.GenerateCSRFToken(r)) + f.UI.SetCSRF(s.deps.GenerateCSRFToken(r)) } return err @@ -129,11 +140,9 @@ func (s *Strategy) getCredentialsFromTraits(ctx context.Context, f *registration return cred, nil } -func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registration.Flow, i *identity.Identity) error { - s.deps.Audit(). - WithRequest(r). - WithField("registration_flow_id", f.ID). - Info("Registration with the code strategy started.") +func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registration.Flow, i *identity.Identity) (err error) { + ctx, span := s.deps.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.code.strategy.Register") + defer otelx.End(span, &err) if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.deps); err != nil { return err @@ -141,133 +150,135 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat var p updateRegistrationFlowWithCodeMethod if err := registration.DecodeBody(&p, r, s.dx, s.deps.Config(), registrationSchema); err != nil { - return s.HandleRegistrationError(w, r, f, &p, err) + return s.HandleRegistrationError(ctx, r, f, &p, err) } - if err := flow.EnsureCSRF(s.deps, r, f.Type, s.deps.Config().DisableAPIFlowEnforcement(r.Context()), s.deps.GenerateCSRFToken, p.CSRFToken); err != nil { - return s.HandleRegistrationError(w, r, f, &p, err) + if err := flow.EnsureCSRF(s.deps, r, f.Type, s.deps.Config().DisableAPIFlowEnforcement(ctx), s.deps.GenerateCSRFToken, p.CSRFToken); err != nil { + return s.HandleRegistrationError(ctx, r, f, &p, err) } - codeManager := NewCodeStateManager(f, s, &p) + // By Default the flow should be in the 'choose method' state. + SetDefaultFlowState(f, p.Resend) - codeManager.SetCreateCodeHandler(func(ctx context.Context, f *registration.Flow, strategy *Strategy, p *updateRegistrationFlowWithCodeMethod) error { - strategy.deps.Logger(). - WithSensitiveField("traits", p.Traits). - WithSensitiveField("transient_paylaod", p.TransientPayload). - Debug("Creating registration code.") + switch f.GetState() { + case flow.StateChooseMethod: + return s.HandleRegistrationError(ctx, r, f, &p, s.registrationSendEmail(ctx, w, r, f, &p, i)) + case flow.StateEmailSent: + return s.HandleRegistrationError(ctx, r, f, &p, s.registrationVerifyCode(ctx, f, &p, i)) + case flow.StatePassedChallenge: + return s.HandleRegistrationError(ctx, r, f, &p, errors.WithStack(schema.NewNoRegistrationStrategyResponsible())) + } - // Create the Registration code + return s.HandleRegistrationError(ctx, r, f, &p, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unexpected flow state: %s", f.GetState()))) +} - // Step 1: validate the identity's traits - cred, err := strategy.getCredentialsFromTraits(ctx, f, i, p.Traits, p.TransientPayload) - if err != nil { - return err - } +func (s *Strategy) registrationSendEmail(ctx context.Context, w http.ResponseWriter, r *http.Request, f *registration.Flow, p *updateRegistrationFlowWithCodeMethod, i *identity.Identity) (err error) { + ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.registrationSendEmail") + defer otelx.End(span, &err) - // Step 2: Delete any previous registration codes for this flow ID - if err := strategy.deps.RegistrationCodePersister().DeleteRegistrationCodesOfFlow(ctx, f.ID); err != nil { - return errors.WithStack(err) - } + if len(p.Traits) == 0 { + return errors.WithStack(schema.NewRequiredError("#/traits", "traits")) + } - // Step 3: Get the identity email and send the code - var addresses []Address - for _, identifier := range cred.Identifiers { - addresses = append(addresses, Address{To: identifier, Via: identity.CodeAddressType(cred.IdentifierAddressType)}) - } - // kratos only supports `email` identifiers at the moment with the code method - // this is validated in the identity validation step above - if err := strategy.deps.CodeSender().SendCode(ctx, f, i, addresses...); err != nil { - return errors.WithStack(err) - } + // Create the Registration code - // sets the flow state to code sent - strategy.NextFlowState(f) + // Step 1: validate the identity's traits + cred, err := s.getCredentialsFromTraits(ctx, f, i, p.Traits, p.TransientPayload) + if err != nil { + return err + } - // Step 4: Generate the UI for the `code` input form - // re-initialize the UI with a "clean" new state - // this should also provide a "resend" button and an option to change the email address - if err := strategy.NewCodeUINodes(r, f, p.Traits); err != nil { - return errors.WithStack(err) - } + // Step 2: Delete any previous registration codes for this flow ID + if err := s.deps.RegistrationCodePersister().DeleteRegistrationCodesOfFlow(ctx, f.ID); err != nil { + return errors.WithStack(err) + } - f.Active = identity.CredentialsTypeCodeAuth - if err := strategy.deps.RegistrationFlowPersister().UpdateRegistrationFlow(ctx, f); err != nil { - return errors.WithStack(err) - } + // Step 3: Get the identity email and send the code + var addresses []Address + for _, identifier := range cred.Identifiers { + addresses = append(addresses, Address{To: identifier, Via: identity.AddressTypeEmail}) + } + // kratos only supports `email` identifiers at the moment with the code method + // this is validated in the identity validation step above + if err := s.deps.CodeSender().SendCode(ctx, f, i, addresses...); err != nil { + return errors.WithStack(err) + } - if x.IsJSONRequest(r) { - strategy.deps.Writer().WriteCode(w, r, http.StatusBadRequest, f) - } else { - http.Redirect(w, r, f.AppendTo(s.deps.Config().SelfServiceFlowRegistrationUI(ctx)).String(), http.StatusSeeOther) - } + // sets the flow state to code sent + f.SetState(flow.NextState(f.GetState())) - // we return an error to the flow handler so that it does not continue execution of the hooks. - // we are not done with the registration flow yet. The user needs to verify the code and then we need to persist the identity. - return errors.WithStack(flow.ErrCompletedByStrategy) - }) - - codeManager.SetCodeVerifyHandler(func(ctx context.Context, f *registration.Flow, strategy *Strategy, p *updateRegistrationFlowWithCodeMethod) error { - strategy.deps.Logger(). - WithSensitiveField("traits", p.Traits). - WithSensitiveField("transient_payload", p.TransientPayload). - WithSensitiveField("code", p.Code). - Debug("Verifying registration code") - - // Step 1: Re-validate the identity's traits - // this is important since the client could have switched out the identity's traits - // this method also returns the credentials for a temporary identity - cred, err := strategy.getCredentialsFromTraits(ctx, f, i, p.Traits, p.TransientPayload) - if err != nil { - return err - } + // Step 4: Generate the UI for the `code` input form + // re-initialize the UI with a "clean" new state + // this should also provide a "resend" button and an option to change the email address + if err := s.NewCodeUINodes(r, f, p.Traits); err != nil { + return errors.WithStack(err) + } - // Step 2: Check if the flow traits match the identity traits - for _, n := range container.NewFromJSON("", node.DefaultGroup, p.Traits, "traits").Nodes { - if !strings.EqualFold(f.GetUI().GetNodes().Find(n.ID()).Attributes.GetValue().(string), n.Attributes.GetValue().(string)) { - return errors.WithStack(schema.NewTraitsMismatch()) - } - } + f.Active = identity.CredentialsTypeCodeAuth + if err := s.deps.RegistrationFlowPersister().UpdateRegistrationFlow(ctx, f); err != nil { + return errors.WithStack(err) + } - // Step 3: Attempt to use the code - registrationCode, err := strategy.deps.RegistrationCodePersister().UseRegistrationCode(ctx, f.ID, p.Code, cred.Identifiers...) - if err != nil { - if errors.Is(err, ErrCodeNotFound) { - return errors.WithStack(schema.NewRegistrationCodeInvalid()) - } - return errors.WithStack(err) - } + if x.IsJSONRequest(r) { + s.deps.Writer().WriteCode(w, r, http.StatusBadRequest, f) + } else { + http.Redirect(w, r, f.AppendTo(s.deps.Config().SelfServiceFlowRegistrationUI(ctx)).String(), http.StatusSeeOther) + } - // Step 4: The code was correct, populate the Identity credentials and traits - if err := strategy.handleIdentityTraits(ctx, f, p.Traits, p.TransientPayload, i, WithCredentials(registrationCode.AddressType, registrationCode.UsedAt)); err != nil { - return errors.WithStack(err) - } + // we return an error to the flow handler so that it does not continue execution of the hooks. + // we are not done with the registration flow yet. The user needs to verify the code and then we need to persist the identity. + return errors.WithStack(flow.ErrCompletedByStrategy) + +} - // since nothing has errored yet, we can assume that the code is correct - // and we can update the registration flow - strategy.NextFlowState(f) +func (s *Strategy) registrationVerifyCode(ctx context.Context, f *registration.Flow, p *updateRegistrationFlowWithCodeMethod, i *identity.Identity) (err error) { + ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.registrationVerifyCode") + defer otelx.End(span, &err) - if err := strategy.deps.RegistrationFlowPersister().UpdateRegistrationFlow(ctx, f); err != nil { - return errors.WithStack(err) - } + if len(p.Code) == 0 { + return errors.WithStack(schema.NewRequiredError("#/code", "code")) + } - return nil - }) + if len(p.Traits) == 0 { + return errors.WithStack(schema.NewRequiredError("#/traits", "traits")) + } - codeManager.SetCodeDoneHandler(func(ctx context.Context, f *registration.Flow, strategy *Strategy, p *updateRegistrationFlowWithCodeMethod) error { - strategy.deps.Audit(). - WithSensitiveField("traits", p.Traits). - WithSensitiveField("transient_payload", p.TransientPayload). - WithSensitiveField("code", p.Code). - Debug("The registration flow has already been completed, but is being re-requested.") + // Step 1: Re-validate the identity's traits + // this is important since the client could have switched out the identity's traits + // this method also returns the credentials for a temporary identity + cred, err := s.getCredentialsFromTraits(ctx, f, i, p.Traits, p.TransientPayload) + if err != nil { + return err + } - return errors.WithStack(schema.NewNoRegistrationStrategyResponsible()) - }) + // Step 2: Check if the flow traits match the identity traits + for _, n := range container.NewFromJSON("", node.DefaultGroup, p.Traits, "traits").Nodes { + if !strings.EqualFold(f.GetUI().GetNodes().Find(n.ID()).Attributes.GetValue().(string), n.Attributes.GetValue().(string)) { + return errors.WithStack(schema.NewTraitsMismatch()) + } + } - if err := codeManager.Run(r.Context()); err != nil { - if errors.Is(err, flow.ErrCompletedByStrategy) { - return err + // Step 3: Attempt to use the code + registrationCode, err := s.deps.RegistrationCodePersister().UseRegistrationCode(ctx, f.ID, p.Code, cred.Identifiers...) + if err != nil { + if errors.Is(err, ErrCodeNotFound) { + return errors.WithStack(schema.NewRegistrationCodeInvalid()) } - return s.HandleRegistrationError(w, r, f, &p, err) + return errors.WithStack(err) } + + // Step 4: The code was correct, populate the Identity credentials and traits + if err := s.handleIdentityTraits(ctx, f, p.Traits, p.TransientPayload, i, WithCredentials(registrationCode.AddressType, registrationCode.UsedAt)); err != nil { + return errors.WithStack(err) + } + + // since nothing has errored yet, we can assume that the code is correct + // and we can update the registration flow + f.SetState(flow.NextState(f.GetState())) + + if err := s.deps.RegistrationFlowPersister().UpdateRegistrationFlow(ctx, f); err != nil { + return errors.WithStack(err) + } + return nil } diff --git a/selfservice/strategy/code/strategy_registration_test.go b/selfservice/strategy/code/strategy_registration_test.go index edac27c8506c..d787b2b3fa4e 100644 --- a/selfservice/strategy/code/strategy_registration_test.go +++ b/selfservice/strategy/code/strategy_registration_test.go @@ -32,10 +32,11 @@ import ( ) type state struct { - flowID string - client *http.Client - email string - testServer *httptest.Server + flowID string + client *http.Client + email string + testServer *httptest.Server + resultIdentity *identity.Identity } func TestRegistrationCodeStrategyDisabled(t *testing.T) { @@ -130,15 +131,16 @@ func TestRegistrationCodeStrategy(t *testing.T) { registerNewUser := func(ctx context.Context, t *testing.T, s *state, isSPA bool, submitAssertion onSubmitAssertion) *state { t.Helper() - email := testhelpers.RandomEmail() - s.email = email + if s.email == "" { + s.email = testhelpers.RandomEmail() + } rf, resp, err := testhelpers.NewSDKCustomClient(s.testServer, s.client).FrontendApi.GetRegistrationFlow(context.Background()).Id(s.flowID).Execute() require.NoError(t, err) require.EqualValues(t, http.StatusOK, resp.StatusCode) values := testhelpers.SDKFormFieldsToURLValues(rf.Ui.Nodes) - values.Set("traits.email", email) + values.Set("traits.email", s.email) values.Set("method", "code") body, resp := testhelpers.RegistrationMakeRequest(t, false, isSPA, rf, s.client, testhelpers.EncodeFormAsJSON(t, false, values)) @@ -155,7 +157,7 @@ func TestRegistrationCodeStrategy(t *testing.T) { } csrfToken := gjson.Get(body, "ui.nodes.#(attributes.name==csrf_token).attributes.value").String() assert.NotEmptyf(t, csrfToken, "%s", body) - require.Equal(t, email, gjson.Get(body, "ui.nodes.#(attributes.name==traits.email).attributes.value").String()) + require.Equal(t, s.email, gjson.Get(body, "ui.nodes.#(attributes.name==traits.email).attributes.value").String()) return s } @@ -186,7 +188,7 @@ func TestRegistrationCodeStrategy(t *testing.T) { verifiableAddress, err := reg.PrivilegedIdentityPool().FindVerifiableAddressByValue(ctx, identity.VerifiableAddressTypeEmail, s.email) require.NoError(t, err) - require.Equal(t, s.email, verifiableAddress.Value) + require.Equal(t, strings.ToLower(s.email), verifiableAddress.Value) id, err := reg.PrivilegedIdentityPool().GetIdentityConfidential(ctx, verifiableAddress.IdentityID) require.NoError(t, err) @@ -195,6 +197,7 @@ func TestRegistrationCodeStrategy(t *testing.T) { _, ok := id.GetCredentials(identity.CredentialsTypeCodeAuth) require.True(t, ok) + s.resultIdentity = id return s } @@ -239,6 +242,35 @@ func TestRegistrationCodeStrategy(t *testing.T) { }, tc.isSPA, nil) }) + t.Run("case=should normalize email address on sign up", func(t *testing.T) { + ctx := context.Background() + + // 1. Initiate flow + state := createRegistrationFlow(ctx, t, public, tc.isSPA) + sourceMail := testhelpers.RandomEmail() + state.email = strings.ToUpper(sourceMail) + assert.NotEqual(t, sourceMail, state.email) + + // 2. Submit Identifier (email) + state = registerNewUser(ctx, t, state, tc.isSPA, nil) + + message := testhelpers.CourierExpectMessage(ctx, t, reg, sourceMail, "Complete your account registration") + assert.Contains(t, message.Body, "please complete your account registration by entering the following code") + + registrationCode := testhelpers.CourierExpectCodeInMessage(t, message, 1) + assert.NotEmpty(t, registrationCode) + + // 3. Submit OTP + state = submitOTP(ctx, t, reg, state, func(v *url.Values) { + v.Set("code", registrationCode) + }, tc.isSPA, nil) + + creds, ok := state.resultIdentity.GetCredentials(identity.CredentialsTypeCodeAuth) + require.True(t, ok) + require.Len(t, creds.Identifiers, 1) + assert.Equal(t, sourceMail, creds.Identifiers[0]) + }) + t.Run("case=should be able to resend the code", func(t *testing.T) { ctx := context.Background() diff --git a/selfservice/strategy/code/strategy_verification.go b/selfservice/strategy/code/strategy_verification.go index aea07b9f6ae9..c02e89105116 100644 --- a/selfservice/strategy/code/strategy_verification.go +++ b/selfservice/strategy/code/strategy_verification.go @@ -236,10 +236,6 @@ func (s *Strategy) verificationUseCode(w http.ResponseWriter, r *http.Request, c return s.retryVerificationFlowWithError(w, r, f.Type, err) } - if err := code.Validate(); err != nil { - return s.retryVerificationFlowWithError(w, r, f.Type, err) - } - i, err := s.deps.IdentityPool().GetIdentity(r.Context(), code.VerifiableAddress.IdentityID, identity.ExpandDefault) if err != nil { return s.retryVerificationFlowWithError(w, r, f.Type, err) diff --git a/selfservice/strategy/password/registration_test.go b/selfservice/strategy/password/registration_test.go index dacdef8c9ff2..1587888e0168 100644 --- a/selfservice/strategy/password/registration_test.go +++ b/selfservice/strategy/password/registration_test.go @@ -50,8 +50,8 @@ func newRegistrationRegistry(t *testing.T) *driver.RegistryDefault { } func TestRegistration(t *testing.T) { - ctx := context.Background() + t.Run("case=registration", func(t *testing.T) { reg := newRegistrationRegistry(t) conf := reg.Config() diff --git a/test/e2e/cypress/integration/profiles/code/login/error.spec.ts b/test/e2e/cypress/integration/profiles/code/login/error.spec.ts index 635e4415f4cb..9f6cdec24664 100644 --- a/test/e2e/cypress/integration/profiles/code/login/error.spec.ts +++ b/test/e2e/cypress/integration/profiles/code/login/error.spec.ts @@ -20,9 +20,9 @@ context("Login error messages with code method", () => { ].forEach(({ route, profile, app }) => { describe(`for app ${app}`, () => { before(() => { - cy.proxy(app) cy.useConfigProfile(profile) cy.deleteMail() + cy.proxy(app) }) beforeEach(() => { @@ -123,8 +123,7 @@ context("Login error messages with code method", () => { it("should show error message when code is expired", () => { cy.updateConfigFile((config) => { config.selfservice.methods.code = { - registration_enabled: true, - login_enabled: true, + passwordless_enabled: true, config: { lifespan: "1ns", }, @@ -142,7 +141,7 @@ context("Login error messages with code method", () => { cy.url().should("contain", "login") cy.get("@email").then((email) => { - cy.getLoginCodeFromEmail(email.toString()).then((code) => { + cy.getLoginCodeFromEmail(email.toString()).should((code) => { cy.get('input[name="code"]').type(code) }) }) @@ -164,8 +163,7 @@ context("Login error messages with code method", () => { cy.updateConfigFile((config) => { config.selfservice.methods.code = { - registration_enabled: true, - login_enabled: true, + passwordless_enabled: true, config: { lifespan: "1h", }, diff --git a/test/e2e/cypress/integration/profiles/code/login/success.spec.ts b/test/e2e/cypress/integration/profiles/code/login/success.spec.ts index acbce5547899..8c0b3796461a 100644 --- a/test/e2e/cypress/integration/profiles/code/login/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/code/login/success.spec.ts @@ -42,7 +42,7 @@ context("Login success with code method", () => { cy.get('input[name="identifier"]').clear().type(email.toString()) cy.submitCodeForm() - cy.getLoginCodeFromEmail(email.toString()).then((code) => { + cy.getLoginCodeFromEmail(email.toString()).should((code) => { cy.get('input[name="code"]').type(code) cy.get("button[name=method][value=code]").click() @@ -68,13 +68,13 @@ context("Login success with code method", () => { cy.get('input[name="identifier"]').clear().type(email.toString()) cy.submitCodeForm() - cy.getLoginCodeFromEmail(email.toString()).then((code) => { + cy.getLoginCodeFromEmail(email.toString()).should((code) => { cy.wrap(code).as("code1") }) cy.get("button[name=resend]").click() - cy.getLoginCodeFromEmail(email.toString()).then((code) => { + cy.getLoginCodeFromEmail(email.toString()).should((code) => { cy.wrap(code).as("code2") }) @@ -133,6 +133,9 @@ context("Login success with code method", () => { "traits.email2": email2, }, }) + + // There are verification emails from the registration process in the inbox that we need to deleted + // for the assertions below to pass. cy.deleteMail({ atLeast: 1 }) cy.visit(route) @@ -140,7 +143,7 @@ context("Login success with code method", () => { cy.get('input[name="identifier"]').clear().type(email2) cy.submitCodeForm() - cy.getLoginCodeFromEmail(email2).then((code) => { + cy.getLoginCodeFromEmail(email2).should((code) => { cy.get('input[name="code"]').type(code) cy.get("button[name=method][value=code]").click() }) diff --git a/test/e2e/cypress/integration/profiles/code/registration/error.spec.ts b/test/e2e/cypress/integration/profiles/code/registration/error.spec.ts index 684019fa3cb7..a9b98f373e43 100644 --- a/test/e2e/cypress/integration/profiles/code/registration/error.spec.ts +++ b/test/e2e/cypress/integration/profiles/code/registration/error.spec.ts @@ -36,13 +36,12 @@ context("Registration error messages with code method", () => { cy.get('input[name="traits.email"]').type(email) cy.submitCodeForm() - cy.url().should("contain", "registration") cy.get('[data-testid="ui/message/1040005"]').should( "contain", "An email containing a code has been sent to the email address you provided", ) - cy.get(' input[name="code"]').type("invalid-code") + cy.get('input[name="code"]').type("invalid-code") cy.submitCodeForm() cy.get('[data-testid="ui/message/4040003"]').should( @@ -56,12 +55,15 @@ context("Registration error messages with code method", () => { cy.get('input[name="traits.email"]').type(email) cy.submitCodeForm() + cy.get('[data-testid="ui/message/1040005"]').should( + "contain", + "An email containing a code has been sent to the email address you provided", + ) - cy.url().should("contain", "registration") cy.get('input[name="traits.email"]') .clear() .type("changed-email@email.com") - cy.get(' input[name="code"]').type("invalid-code") + cy.get('input[name="code"]').type("invalid-code") cy.submitCodeForm() cy.get('[data-testid="ui/message/4000030"]').should( @@ -75,12 +77,14 @@ context("Registration error messages with code method", () => { cy.get('input[name="traits.email"]').type(email) cy.submitCodeForm() - - cy.url().should("contain", "registration") + cy.get('[data-testid="ui/message/1040005"]').should( + "contain", + "An email containing a code has been sent to the email address you provided", + ) cy.removeAttribute(['input[name="code"]'], "required") - cy.submitCodeForm() + cy.submitCodeForm() cy.get('[data-testid="ui/message/4000002"]').should( "contain", "Property code is missing", @@ -102,14 +106,18 @@ context("Registration error messages with code method", () => { config.selfservice.methods.code.config.lifespan = "1ns" return config }) + cy.visit(route) const email = gen.email() + cy.get('input[name="traits.email"]').type(email) - cy.get(' input[name="traits.email"]').type(email) cy.submitCodeForm() + cy.get('[data-testid="ui/message/1040005"]').should( + "contain", + "An email containing a code has been sent to the email address you provided", + ) - cy.url().should("contain", "registration") - cy.getRegistrationCodeFromEmail(email).then((code) => { + cy.getRegistrationCodeFromEmail(email).should((code) => { cy.get('input[name="code"]').type(code) cy.submitCodeForm() }) diff --git a/test/e2e/cypress/integration/profiles/code/registration/success.spec.ts b/test/e2e/cypress/integration/profiles/code/registration/success.spec.ts index eaf734b144ed..299d5707a7d5 100644 --- a/test/e2e/cypress/integration/profiles/code/registration/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/code/registration/success.spec.ts @@ -38,31 +38,33 @@ context("Registration success with code method", () => { it("should be able to resend the registration code", async () => { const email = gen.email() - cy.get(` input[name='traits.email']`).type(email) + cy.get(`input[name='traits.email']`).type(email) cy.submitCodeForm() + cy.get('[data-testid="ui/message/1040005"]').should( + "contain", + "An email containing a code has been sent to the email address you provided", + ) - cy.url().should("contain", "registration") - - cy.getRegistrationCodeFromEmail(email).then((code) => + cy.getRegistrationCodeFromEmail(email).should((code) => cy.wrap(code).as("code1"), ) - cy.get(` input[name='traits.email']`).should("have.value", email) - cy.get(` input[name='method'][value='code'][type='hidden']`).should( + cy.get(`input[name='traits.email']`).should("have.value", email) + cy.get(`input[name='method'][value='code'][type='hidden']`).should( "exist", ) - cy.get(` button[name='resend'][value='code']`).click() + cy.get(`button[name='resend'][value='code']`).click() - cy.getRegistrationCodeFromEmail(email).then((code) => { + cy.getRegistrationCodeFromEmail(email).should((code) => { cy.wrap(code).as("code2") }) cy.get("@code1").then((code1) => { // previous code should not work cy.get('input[name="code"]').clear().type(code1.toString()) - cy.submitCodeForm() + cy.submitCodeForm() cy.get('[data-testid="ui/message/4040003"]').should( "contain.text", "The registration code is invalid or has already been used. Please try again.", @@ -89,10 +91,13 @@ context("Registration success with code method", () => { cy.get(` input[name='traits.email']`).type(email) cy.submitCodeForm() + cy.get('[data-testid="ui/message/1040005"]').should( + "contain", + "An email containing a code has been sent to the email address you provided", + ) - cy.url().should("contain", "registration") - cy.getRegistrationCodeFromEmail(email).then((code) => { - cy.get(` input[name=code]`).type(code) + cy.getRegistrationCodeFromEmail(email).should((code) => { + cy.get(`input[name=code]`).type(code) cy.get("button[name=method][value=code]").click() }) @@ -109,29 +114,28 @@ context("Registration success with code method", () => { cy.setPostCodeRegistrationHooks([]) const email = gen.email() - cy.get(` input[name='traits.email']`).type(email) + cy.get(`input[name='traits.email']`).type(email) cy.submitCodeForm() + cy.get('[data-testid="ui/message/1040005"]').should( + "contain", + "An email containing a code has been sent to the email address you provided", + ) - cy.url().should("contain", "registration") - cy.getRegistrationCodeFromEmail(email).then((code) => { - cy.get(` input[name=code]`).type(code) + cy.getRegistrationCodeFromEmail(email).should((code) => { + cy.get(`input[name=code]`).type(code) cy.get("button[name=method][value=code]").click() }) - cy.deleteMail({ atLeast: 1 }) - cy.visit(login) - cy.get(` input[name=identifier]`).type(email) + cy.get(`input[name=identifier]`).type(email) cy.get("button[name=method][value=code]").click() cy.getLoginCodeFromEmail(email).then((code) => { - cy.get(`input[name = code]`).type(code) + cy.get(`input[name=code]`).type(code) cy.get("button[name=method][value=code]").click() }) - cy.deleteMail({ atLeast: 1 }) - cy.getSession().should((session) => { const { identity } = session expect(identity.id).to.not.be.empty @@ -179,35 +183,40 @@ context("Registration success with code method", () => { cy.get(`input[name='traits.username']`).type(Math.random().toString(36)) const email = gen.email() - - cy.get(` input[name='traits.email']`).type(email) + cy.get(`input[name='traits.email']`).type(email) const email2 = gen.email() - - cy.get(` input[name='traits.email2']`).type(email2) + cy.get(`input[name='traits.email2']`).type(email2) cy.submitCodeForm() + cy.get('[data-testid="ui/message/1040005"]').should( + "contain", + "An email containing a code has been sent to the email address you provided", + ) - // intentionally use email 1 to verify the account - cy.url().should("contain", "registration") - cy.getRegistrationCodeFromEmail(email, { expectedCount: 2 }).then( + // intentionally use email 1 to sign up for the account + cy.getRegistrationCodeFromEmail(email, { expectedCount: 1 }).should( (code) => { cy.get(`input[name=code]`).type(code) cy.get("button[name=method][value=code]").click() }, ) - cy.deleteMail({ atLeast: 2 }) - cy.logout() + // There are verification emails from the registration process in the inbox that we need to deleted + // for the assertions below to pass. + cy.deleteMail({ atLeast: 1 }) + // Attempt to sign in with email 2 (should fail) cy.visit(login) - cy.get(` input[name=identifier]`).type(email2) + cy.get(`input[name=identifier]`).type(email2) cy.get("button[name=method][value=code]").click() - cy.getLoginCodeFromEmail(email2).then((code) => { + cy.getLoginCodeFromEmail(email2, { + expectedCount: 1, + }).should((code) => { cy.get(`input[name=code]`).type(code) cy.get("button[name=method][value=code]").click() }) diff --git a/test/e2e/cypress/support/commands.ts b/test/e2e/cypress/support/commands.ts index bce417aca9f1..77f8a084389a 100644 --- a/test/e2e/cypress/support/commands.ts +++ b/test/e2e/cypress/support/commands.ts @@ -384,9 +384,13 @@ Cypress.Commands.add( Cypress.Commands.add( "registerWithCode", - ({ email = gen.email(), code = undefined, traits = {}, query = {} } = {}) => { - console.log("Creating user account: ", { email }) - + ({ + email = gen.email(), + code = undefined, + traits = {}, + query = {}, + expectedMailCount = 1, + } = {}) => { cy.clearAllCookies() cy.request({ @@ -394,7 +398,6 @@ Cypress.Commands.add( method: "GET", followRedirect: false, headers: { - "Content-Type": "application/json", Accept: "application/json", }, qs: query || {}, @@ -419,7 +422,6 @@ Cypress.Commands.add( }) .then(({ body }) => { if (!code) { - console.log("registration with code", body) expect( body.ui.nodes.find( (f: UiNode) => @@ -429,14 +431,9 @@ Cypress.Commands.add( ).attributes.value, ).to.eq(email) - const expectedCount = - Object.keys(traits) - .map((k) => (k.includes("email") ? k : null)) - .filter(Boolean).length + 1 - return cy .getRegistrationCodeFromEmail(email, { - expectedCount: expectedCount, + expectedCount: expectedMailCount, }) .then((code) => { return cy.request({ @@ -1190,7 +1187,7 @@ Cypress.Commands.add( Cypress.Commands.add( "verifyEmailButExpired", ({ expect: { email }, strategy = "code" }) => { - cy.getMail().then((message) => { + cy.getMail().should((message) => { expect(message.subject).to.equal("Please verify your email address") expect(message.fromAddress.trim()).to.equal("no-reply@ory.kratos.sh") @@ -1263,28 +1260,32 @@ Cypress.Commands.add( const req = () => cy.request(`${MAIL_API}/mail`).then((response) => { expect(response.body).to.have.property("mailItems") - const count = response.body.mailItems.length + let count = response.body.mailItems.length if (count === 0 && tries < 100) { tries++ cy.wait(pollInterval) return req() } + let mailItem: any if (email) { - mailItem = response.body.mailItems.find((m: any) => + const filtered = response.body.mailItems.filter((m: any) => m.toAddresses.includes(email), ) - if (!mailItem) { - return req + + if (filtered.length === 0) { + tries++ + cy.wait(pollInterval) + return req() } + + expect(filtered.length).to.equal(expectedCount) + mailItem = filtered[0] } else { + expect(count).to.equal(expectedCount) mailItem = response.body.mailItems[0] } - console.log({ mailItems: response.body.mailItems }) - console.log({ mailItem }) - console.log({ email }) - expect(count).to.equal(expectedCount) if (removeMail) { return cy.deleteMail({ atLeast: count }).then(() => { return Promise.resolve(mailItem) @@ -1485,7 +1486,7 @@ Cypress.Commands.add("getVerificationCodeFromEmail", (email) => { Cypress.Commands.add("enableRegistrationViaCode", (enable: boolean = true) => { cy.updateConfigFile((config) => { - config.selfservice.methods.code.registration_enabled = enable + config.selfservice.methods.code.passwordless_enabled = enable return config }) }) diff --git a/test/e2e/cypress/support/config.d.ts b/test/e2e/cypress/support/config.d.ts index 762dbe4e090b..1bac43efd7b3 100644 --- a/test/e2e/cypress/support/config.d.ts +++ b/test/e2e/cypress/support/config.d.ts @@ -527,8 +527,7 @@ export interface OryKratosConfiguration2 { config?: LinkConfiguration } code?: { - login_enabled?: EnablesLoginWithCodeMethod - registration_enabled?: EnablesRegistrationWithCodeMethod + passwordless_enabled?: boolean enabled?: EnablesCodeMethod config?: CodeConfiguration } diff --git a/test/e2e/cypress/support/index.d.ts b/test/e2e/cypress/support/index.d.ts index a30442ee68af..47b66308c252 100644 --- a/test/e2e/cypress/support/index.d.ts +++ b/test/e2e/cypress/support/index.d.ts @@ -80,6 +80,7 @@ declare global { code?: string traits?: { [key: string]: any } query?: { [key: string]: string } + expectedMailCount?: number }): Chainable> /** @@ -731,7 +732,7 @@ declare global { */ getRegistrationCodeFromEmail( email: string, - opts?: { expectedCount: number }, + opts?: { expectedCount: number; removeMail?: boolean }, ): Chainable /** diff --git a/test/e2e/profiles/code/.kratos.yml b/test/e2e/profiles/code/.kratos.yml index 05820e0002e9..ec69fb050fab 100644 --- a/test/e2e/profiles/code/.kratos.yml +++ b/test/e2e/profiles/code/.kratos.yml @@ -37,8 +37,7 @@ selfservice: password: enabled: false code: - registration_enabled: true - login_enabled: true + passwordless_enabled: true enabled: true config: lifespan: 1h