From 4808d9c8ab553e30e0f469b495c641d866ef71d8 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Mon, 9 Sep 2024 14:18:22 +0200 Subject: [PATCH 01/21] feat: partial failures for batch import --- identity/handler.go | 13 +- identity/handler_test.go | 41 ++++-- identity/identity.go | 21 +-- identity/manager.go | 69 +++++++-- identity/pool.go | 6 +- persistence/sql/batch/create.go | 134 +++++++----------- .../sql/identity/persister_identity.go | 97 ++++++++++++- 7 files changed, 259 insertions(+), 122 deletions(-) diff --git a/identity/handler.go b/identity/handler.go index cf85dc792c43..c4f90cec7619 100644 --- a/identity/handler.go +++ b/identity/handler.go @@ -617,13 +617,22 @@ func (h *Handler) batchPatchIdentities(w http.ResponseWriter, r *http.Request, _ } } - if err := h.r.IdentityManager().CreateIdentities(r.Context(), identities); err != nil { + err := h.r.IdentityManager().CreateIdentities(r.Context(), identities) + partialErr := new(CreateIdentitiesError) + if err != nil && !errors.As(err, &partialErr) { h.r.Writer().WriteError(w, r, err) return } for resIdx, identitiesIdx := range indexInIdentities { if identitiesIdx != nil { - res.Identities[resIdx].IdentityID = &identities[*identitiesIdx].ID + ident := identities[*identitiesIdx] + // Check if the identity was created successfully. + if failed := partialErr.Find(ident); failed != nil { + res.Identities[resIdx].Action = ActionError + res.Identities[resIdx].Error = failed.Error + } else { + res.Identities[resIdx].IdentityID = &ident.ID + } } } diff --git a/identity/handler_test.go b/identity/handler_test.go index d97c2f73fae4..2162095c0431 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -774,34 +774,49 @@ func TestHandler(t *testing.T) { } for _, tt := range []struct { - name string - body *identity.CreateIdentityBody - expectStatus int + name string + body *identity.CreateIdentityBody }{ { - name: "missing all fields", - body: &identity.CreateIdentityBody{}, - expectStatus: http.StatusBadRequest, + name: "missing-all-fields", + body: &identity.CreateIdentityBody{}, }, { - name: "duplicate identity", - body: validCreateIdentityBody("valid-patch", 0), - expectStatus: http.StatusConflict, + name: "duplicate-identity", + body: validCreateIdentityBody("duplicate-identity", 0), }, { - name: "invalid traits", + name: "invalid-traits", body: &identity.CreateIdentityBody{ Traits: json.RawMessage(`"invalid traits"`), }, - expectStatus: http.StatusBadRequest, }, } { t.Run("invalid because "+tt.name, func(t *testing.T) { - patches := append([]*identity.BatchIdentityPatch{}, validPatches...) + validPatches := []*identity.BatchIdentityPatch{ + {Create: validCreateIdentityBody(tt.name, 0)}, + {Create: validCreateIdentityBody(tt.name, 1)}, + {Create: validCreateIdentityBody(tt.name, 2)}, + {Create: validCreateIdentityBody(tt.name, 3)}, + {Create: validCreateIdentityBody(tt.name, 4)}, + } + + patches := make([]*identity.BatchIdentityPatch, 0, len(validPatches)+1) + patches = append(patches, validPatches[0:3]...) patches = append(patches, &identity.BatchIdentityPatch{Create: tt.body}) + patches = append(patches, validPatches[3:5]...) + for i, p := range patches { + id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%s-%d", tt.name, i)) + p.ID = &id + } req := &identity.BatchPatchIdentitiesBody{Identities: patches} - send(t, adminTS, "PATCH", "/identities", tt.expectStatus, req) + body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) + var actions []string + for _, a := range body.Get("identities.#.action").Array() { + actions = append(actions, a.String()) + } + assert.Equal(t, []string{"create", "create", "create", "error", "create", "create"}, actions, body) }) } diff --git a/identity/identity.go b/identity/identity.go index 0277772c4708..d21cadb36ab3 100644 --- a/identity/identity.go +++ b/identity/identity.go @@ -11,22 +11,17 @@ import ( "sync" "time" + "github.com/gofrs/uuid" + "github.com/pkg/errors" "github.com/samber/lo" - - "github.com/tidwall/sjson" - "github.com/tidwall/gjson" - - "github.com/ory/kratos/cipher" + "github.com/tidwall/sjson" "github.com/ory/herodot" + "github.com/ory/kratos/cipher" + "github.com/ory/kratos/driver/config" "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/sqlxx" - - "github.com/ory/kratos/driver/config" - - "github.com/gofrs/uuid" - "github.com/pkg/errors" ) // An Identity's State @@ -645,6 +640,9 @@ const ( // Create this identity. ActionCreate BatchPatchAction = "create" + // Error indicates that the patch failed. + ActionError BatchPatchAction = "error" + // Future actions: // // Delete this identity. @@ -677,4 +675,7 @@ type BatchIdentityPatchResponse struct { // The ID of this patch response, if an ID was specified in the patch. PatchID *uuid.UUID `json:"patch_id,omitempty"` + + // The error message, if the action was "error". + Error *herodot.DefaultError `json:"error,omitempty"` } diff --git a/identity/manager.go b/identity/manager.go index d05fd83a4a7f..2dc41bac0e29 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -6,6 +6,7 @@ package identity import ( "context" "encoding/json" + "fmt" "reflect" "slices" "sort" @@ -326,30 +327,78 @@ func (e *ErrDuplicateCredentials) HasHints() bool { return len(e.availableCredentials) > 0 || len(e.availableOIDCProviders) > 0 || len(e.identifierHint) > 0 } +type FailedIdentity struct { + Identity *Identity + Error *herodot.DefaultError +} + +type CreateIdentitiesError struct { + Failed []*FailedIdentity +} + +func (e *CreateIdentitiesError) Error() string { + return fmt.Sprintf("create identities error: %d identities failed", len(e.Failed)) +} +func (e *CreateIdentitiesError) Contains(ident *Identity) bool { + for _, failed := range e.Failed { + if failed.Identity.ID == ident.ID { + return true + } + } + return false +} +func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity { + for _, failed := range e.Failed { + if failed.Identity.ID == ident.ID { + return failed + } + } + return nil +} +func (e *CreateIdentitiesError) ErrOrNil() error { + if len(e.Failed) == 0 { + return nil + } + return e +} + func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, opts ...ManagerOption) (err error) { ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities") defer otelx.End(span, &err) - for _, i := range identities { - if i.SchemaID == "" { - i.SchemaID = m.r.Config().DefaultIdentityTraitsSchemaID(ctx) + createIdentitiesError := &CreateIdentitiesError{} + validIdentities := make([]*Identity, 0, len(identities)) + for _, ident := range identities { + if ident.SchemaID == "" { + ident.SchemaID = m.r.Config().DefaultIdentityTraitsSchemaID(ctx) } o := newManagerOptions(opts) - if err := m.ValidateIdentity(ctx, i, o); err != nil { - return err + if err := m.ValidateIdentity(ctx, ident, o); err != nil { + createIdentitiesError.Failed = append(createIdentitiesError.Failed, &FailedIdentity{ + Identity: ident, + Error: herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err), + }) + continue } + validIdentities = append(validIdentities, ident) } - if err := m.r.PrivilegedIdentityPool().CreateIdentities(ctx, identities...); err != nil { - return err + if err := m.r.PrivilegedIdentityPool().CreateIdentities(ctx, validIdentities...); err != nil { + if partialErr := new(CreateIdentitiesError); errors.As(err, &partialErr) { + createIdentitiesError.Failed = append(createIdentitiesError.Failed, partialErr.Failed...) + } else { + return err + } } - for _, i := range identities { - trace.SpanFromContext(ctx).AddEvent(events.NewIdentityCreated(ctx, i.ID)) + for _, ident := range validIdentities { + if !createIdentitiesError.Contains(ident) { + trace.SpanFromContext(ctx).AddEvent(events.NewIdentityCreated(ctx, ident.ID)) + } } - return nil + return createIdentitiesError.ErrOrNil() } func (m *Manager) requiresPrivilegedAccess(ctx context.Context, original, updated *Identity, o *ManagerOptions) (err error) { diff --git a/identity/pool.go b/identity/pool.go index 8a94aad3e075..86559f0a8a3f 100644 --- a/identity/pool.go +++ b/identity/pool.go @@ -61,9 +61,13 @@ type ( FindByCredentialsIdentifier(ctx context.Context, ct CredentialsType, match string) (*Identity, *Credentials, error) // DeleteIdentity removes an identity by its id. Will return an error - // if identity exists, backend connectivity is broken, or trait validation fails. + // if identity does not exists, or backend connectivity is broken. DeleteIdentity(context.Context, uuid.UUID) error + // DeleteIdentities removes identities by its id. Will return an error + // if any identity does not exists, or backend connectivity is broken. + DeleteIdentities(context.Context, []uuid.UUID) error + // UpdateVerifiableAddress updates an identity's verifiable address. UpdateVerifiableAddress(ctx context.Context, address *VerifiableAddress) error diff --git a/persistence/sql/batch/create.go b/persistence/sql/batch/create.go index 38254a3b2a80..67b9b49cc1f0 100644 --- a/persistence/sql/batch/create.go +++ b/persistence/sql/batch/create.go @@ -5,26 +5,23 @@ package batch import ( "context" - "database/sql" "fmt" "reflect" + "slices" "sort" "strings" "time" + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" "github.com/jmoiron/sqlx/reflectx" + "github.com/pkg/errors" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "github.com/ory/x/dbal" - - "github.com/gobuffalo/pop/v6" - "github.com/gofrs/uuid" - "github.com/pkg/errors" - "github.com/ory/x/otelx" "github.com/ory/x/sqlcon" - "github.com/ory/x/sqlxx" ) @@ -42,8 +39,22 @@ type ( Tracer *otelx.Tracer Connection *pop.Connection } + + PartialConflictError[T any] struct { + Failed []*T + } ) +func (p *PartialConflictError[T]) Error() string { + return fmt.Sprintf("partial conflict error: %d models failed to insert", len(p.Failed)) +} +func (p *PartialConflictError[T]) ErrOrNil() error { + if len(p.Failed) == 0 { + return nil + } + return p +} + func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *reflectx.Mapper, quoter quoter, models []*T) insertQueryArgs { var ( v T @@ -73,33 +84,10 @@ func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *re // (?, ?, ?, ?), // (?, ?, ?, ?), // (?, ?, ?, ?) - for _, m := range models { - m := reflect.ValueOf(m) - + for range models { pl := make([]string, len(placeholderRow)) copy(pl, placeholderRow) - // There is a special case - when using CockroachDB we want to generate - // UUIDs using "gen_random_uuid()" which ends up in a VALUE statement of: - // - // (gen_random_uuid(), ?, ?, ?), - for k := range placeholderRow { - if columns[k] != "id" { - continue - } - - field := mapper.FieldByName(m, columns[k]) - val, ok := field.Interface().(uuid.UUID) - if !ok { - continue - } - - if val == uuid.Nil && dialect == dbal.DriverCockroachDB { - pl[k] = "gen_random_uuid()" - break - } - } - placeholders = append(placeholders, fmt.Sprintf("(%s)", strings.Join(pl, ", "))) } @@ -130,19 +118,6 @@ func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, colu case "id": if field.Interface().(uuid.UUID) != uuid.Nil { break // breaks switch, not for - } else if dialect == dbal.DriverCockroachDB { - // This is a special case: - // 1. We're using cockroach - // 2. It's the primary key field ("ID") - // 3. A UUID was not yet set. - // - // If all these conditions meet, the VALUE statement will look as such: - // - // (gen_random_uuid(), ?, ?, ?, ...) - // - // For that reason, we do not add the ID value to the list of arguments, - // because one of the arguments is using a built-in and thus doesn't need a value. - continue // break switch, not for } id, err := uuid.NewV4() @@ -196,7 +171,7 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e var returningClause string if conn.Dialect.Name() != dbal.DriverMySQL { // PostgreSQL, CockroachDB, SQLite support RETURNING. - returningClause = fmt.Sprintf("RETURNING %s", model.IDField()) + returningClause = fmt.Sprintf("ON CONFLICT DO NOTHING RETURNING %s", model.IDField()) } query := conn.Dialect.TranslateSQL(fmt.Sprintf( @@ -211,23 +186,30 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e if err != nil { return sqlcon.HandleError(err) } - defer rows.Close() - // Hydrate the models from the RETURNING clause. - // - // Databases not supporting RETURNING will just return 0 rows. - count := 0 + idIdx := slices.Index(queryArgs.Columns, "id") + if idIdx == -1 { + return errors.New("id column not found") + } + var idValues []uuid.UUID + for i := idIdx; i < len(values); i += len(queryArgs.Columns) { + idValues = append(idValues, values[i].(uuid.UUID)) + } + + // Hydrate the models from the RETURNING clause. Note that MySQL, which does not + // support RETURNING, also does not have ON CONFLICT DO NOTHING, meaning that + // MySQL will always fail the whole transaction on a single record conflict. + idsInDB := make(map[uuid.UUID]struct{}) for rows.Next() { if err := rows.Err(); err != nil { return sqlcon.HandleError(err) } - - if err := setModelID(rows, pop.NewModel(models[count], ctx)); err != nil { - return err + var id uuid.UUID + if err := rows.Scan(&id); err != nil { + return errors.WithStack(err) } - count++ + idsInDB[id] = struct{}{} } - if err := rows.Err(); err != nil { return sqlcon.HandleError(err) } @@ -236,43 +218,29 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e return sqlcon.HandleError(err) } - return sqlcon.HandleError(err) + var partialConflictError PartialConflictError[T] + for i, id := range idValues { + if _, ok := idsInDB[id]; !ok { + partialConflictError.Failed = append(partialConflictError.Failed, models[i]) + } else { + if err := setModelID(id, pop.NewModel(models[i], ctx)); err != nil { + return err + } + } + } + + return partialConflictError.ErrOrNil() } // setModelID was copy & pasted from pop. It basically sets // the primary key to the given value read from the SQL row. -func setModelID(row *sql.Rows, model *pop.Model) error { +func setModelID(id uuid.UUID, model *pop.Model) error { el := reflect.ValueOf(model.Value).Elem() fbn := el.FieldByName("ID") if !fbn.IsValid() { return errors.New("model does not have a field named id") } - - pkt, err := model.PrimaryKeyType() - if err != nil { - return errors.WithStack(err) - } - - switch pkt { - case "UUID": - var id uuid.UUID - if err := row.Scan(&id); err != nil { - return errors.WithStack(err) - } - fbn.Set(reflect.ValueOf(id)) - default: - var id interface{} - if err := row.Scan(&id); err != nil { - return errors.WithStack(err) - } - v := reflect.ValueOf(id) - switch fbn.Kind() { - case reflect.Int, reflect.Int64: - fbn.SetInt(v.Int()) - default: - fbn.Set(reflect.ValueOf(id)) - } - } + fbn.Set(reflect.ValueOf(id)) return nil } diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 807d3d67779d..1616880091e3 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -580,15 +580,68 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... p.normalizeAllAddressess(ctx, identities...) + failedIdentityIDs := make(map[uuid.UUID]struct{}) + if err = p.createVerifiableAddresses(ctx, tx, identities...); err != nil { - return sqlcon.HandleError(err) + if paritalErr := new(batch.PartialConflictError[identity.VerifiableAddress]); errors.As(err, &paritalErr) { + for _, k := range paritalErr.Failed { + failedIdentityIDs[k.IdentityID] = struct{}{} + } + } else { + return sqlcon.HandleError(err) + } } if err = p.createRecoveryAddresses(ctx, tx, identities...); err != nil { - return sqlcon.HandleError(err) + if paritalErr := new(batch.PartialConflictError[identity.RecoveryAddress]); errors.As(err, &paritalErr) { + for _, k := range paritalErr.Failed { + failedIdentityIDs[k.IdentityID] = struct{}{} + } + } else { + return sqlcon.HandleError(err) + } } if err = p.createIdentityCredentials(ctx, tx, identities...); err != nil { - return sqlcon.HandleError(err) + if paritalErr := new(batch.PartialConflictError[identity.Credentials]); errors.As(err, &paritalErr) { + for _, k := range paritalErr.Failed { + failedIdentityIDs[k.IdentityID] = struct{}{} + } + + } else if paritalErr := new(batch.PartialConflictError[identity.CredentialIdentifier]); errors.As(err, &paritalErr) { + for _, k := range paritalErr.Failed { + credID := k.IdentityCredentialsID + for _, ident := range identities { + for _, cred := range ident.Credentials { + if cred.ID == credID { + failedIdentityIDs[ident.ID] = struct{}{} + } + } + } + } + } else { + return sqlcon.HandleError(err) + } } + + // If any of the batch inserts failed on conflict, let's delete the corresponding + // identities and return a list of failed identities in the error. + if len(failedIdentityIDs) > 0 { + partialErr := &identity.CreateIdentitiesError{} + failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs)) + for _, ident := range identities { + if _, ok := failedIdentityIDs[ident.ID]; ok { + partialErr.Failed = append(partialErr.Failed, &identity.FailedIdentity{ + Identity: ident, + Error: herodot.ErrConflict.WithReason("This identity conflicts with another identity that already exists."), + }) + failedIDs = append(failedIDs, ident.ID) + } + } + if err := p.DeleteIdentities(ctx, failedIDs); err != nil { + return sqlcon.HandleError(err) + } + return partialErr + } + return nil }) } @@ -1031,6 +1084,44 @@ func (p *IdentityPersister) DeleteIdentity(ctx context.Context, id uuid.UUID) (e return nil } +func (p *IdentityPersister) DeleteIdentities(ctx context.Context, ids []uuid.UUID) (err error) { + stringIDs := make([]string, len(ids)) + for k, id := range ids { + stringIDs[k] = id.String() + } + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteIdentites", + trace.WithAttributes( + attribute.StringSlice("identity.ids", stringIDs), + attribute.Stringer("network.id", p.NetworkID(ctx)))) + defer otelx.End(span, &err) + + placeholders := strings.TrimSuffix(strings.Repeat("?, ", len(ids)), ", ") + args := make([]any, 0, len(ids)+1) + for _, id := range ids { + args = append(args, id) + } + args = append(args, p.NetworkID(ctx)) + + tableName := new(identity.Identity).TableName(ctx) + if p.c.Dialect.Name() == "cockroach" { + tableName += "@primary" + } + count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( + "DELETE FROM %s WHERE id IN (%s) AND nid = ?", + tableName, + placeholders, + ), + args..., + ).ExecWithCount() + if err != nil { + return sqlcon.HandleError(err) + } + if count != len(ids) { + return errors.WithStack(sqlcon.ErrNoRows) + } + return nil +} + func (p *IdentityPersister) GetIdentity(ctx context.Context, id uuid.UUID, expand identity.Expandables) (_ *identity.Identity, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetIdentity", trace.WithAttributes( From 03b65342fae2dda748d631a1daf1c80cf8db8b25 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Mon, 9 Sep 2024 14:19:49 +0200 Subject: [PATCH 02/21] chore: regenerate sdk --- internal/client-go/.openapi-generator/FILES | 2 + internal/client-go/README.md | 1 + internal/client-go/go.sum | 1 - .../model_identity_credentials_code.go | 85 +++------- ...model_identity_credentials_code_address.go | 151 ++++++++++++++++++ .../model_identity_patch_response.go | 41 ++++- ...odel_update_login_flow_with_code_method.go | 37 +++++ internal/httpclient/.openapi-generator/FILES | 2 + internal/httpclient/README.md | 1 + .../model_identity_credentials_code.go | 85 +++------- ...model_identity_credentials_code_address.go | 151 ++++++++++++++++++ .../model_identity_patch_response.go | 41 ++++- ...odel_update_login_flow_with_code_method.go | 37 +++++ spec/api.json | 40 +++-- spec/swagger.json | 54 ++++--- 15 files changed, 558 insertions(+), 171 deletions(-) create mode 100644 internal/client-go/model_identity_credentials_code_address.go create mode 100644 internal/httpclient/model_identity_credentials_code_address.go diff --git a/internal/client-go/.openapi-generator/FILES b/internal/client-go/.openapi-generator/FILES index 5eaa392f30a9..118cf9b06463 100644 --- a/internal/client-go/.openapi-generator/FILES +++ b/internal/client-go/.openapi-generator/FILES @@ -42,6 +42,7 @@ docs/Identity.md docs/IdentityAPI.md docs/IdentityCredentials.md docs/IdentityCredentialsCode.md +docs/IdentityCredentialsCodeAddress.md docs/IdentityCredentialsOidc.md docs/IdentityCredentialsOidcProvider.md docs/IdentityCredentialsPassword.md @@ -165,6 +166,7 @@ model_health_status.go model_identity.go model_identity_credentials.go model_identity_credentials_code.go +model_identity_credentials_code_address.go model_identity_credentials_oidc.go model_identity_credentials_oidc_provider.go model_identity_credentials_password.go diff --git a/internal/client-go/README.md b/internal/client-go/README.md index 6e2097d9a320..97593523117a 100644 --- a/internal/client-go/README.md +++ b/internal/client-go/README.md @@ -166,6 +166,7 @@ Class | Method | HTTP request | Description - [Identity](docs/Identity.md) - [IdentityCredentials](docs/IdentityCredentials.md) - [IdentityCredentialsCode](docs/IdentityCredentialsCode.md) + - [IdentityCredentialsCodeAddress](docs/IdentityCredentialsCodeAddress.md) - [IdentityCredentialsOidc](docs/IdentityCredentialsOidc.md) - [IdentityCredentialsOidcProvider](docs/IdentityCredentialsOidcProvider.md) - [IdentityCredentialsPassword](docs/IdentityCredentialsPassword.md) diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index 6cc3f5911d11..c966c8ddfd0d 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,7 +4,6 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/client-go/model_identity_credentials_code.go b/internal/client-go/model_identity_credentials_code.go index 75857f31c272..53fefb6719eb 100644 --- a/internal/client-go/model_identity_credentials_code.go +++ b/internal/client-go/model_identity_credentials_code.go @@ -13,14 +13,11 @@ package client import ( "encoding/json" - "time" ) // IdentityCredentialsCode CredentialsCode represents a one time login/registration code type IdentityCredentialsCode struct { - // The type of the address for this code - AddressType *string `json:"address_type,omitempty"` - UsedAt NullableTime `json:"used_at,omitempty"` + Addresses []IdentityCredentialsCodeAddress `json:"addresses,omitempty"` } // NewIdentityCredentialsCode instantiates a new IdentityCredentialsCode object @@ -40,88 +37,42 @@ func NewIdentityCredentialsCodeWithDefaults() *IdentityCredentialsCode { return &this } -// GetAddressType returns the AddressType field value if set, zero value otherwise. -func (o *IdentityCredentialsCode) GetAddressType() string { - if o == nil || o.AddressType == nil { - var ret string +// GetAddresses returns the Addresses field value if set, zero value otherwise. +func (o *IdentityCredentialsCode) GetAddresses() []IdentityCredentialsCodeAddress { + if o == nil || o.Addresses == nil { + var ret []IdentityCredentialsCodeAddress return ret } - return *o.AddressType + return o.Addresses } -// GetAddressTypeOk returns a tuple with the AddressType field value if set, nil otherwise +// GetAddressesOk returns a tuple with the Addresses field value if set, nil otherwise // and a boolean to check if the value has been set. -func (o *IdentityCredentialsCode) GetAddressTypeOk() (*string, bool) { - if o == nil || o.AddressType == nil { +func (o *IdentityCredentialsCode) GetAddressesOk() ([]IdentityCredentialsCodeAddress, bool) { + if o == nil || o.Addresses == nil { return nil, false } - return o.AddressType, true + return o.Addresses, true } -// HasAddressType returns a boolean if a field has been set. -func (o *IdentityCredentialsCode) HasAddressType() bool { - if o != nil && o.AddressType != nil { +// HasAddresses returns a boolean if a field has been set. +func (o *IdentityCredentialsCode) HasAddresses() bool { + if o != nil && o.Addresses != nil { return true } return false } -// SetAddressType gets a reference to the given string and assigns it to the AddressType field. -func (o *IdentityCredentialsCode) SetAddressType(v string) { - o.AddressType = &v -} - -// GetUsedAt returns the UsedAt field value if set, zero value otherwise (both if not set or set to explicit null). -func (o *IdentityCredentialsCode) GetUsedAt() time.Time { - if o == nil || o.UsedAt.Get() == nil { - var ret time.Time - return ret - } - return *o.UsedAt.Get() -} - -// GetUsedAtOk returns a tuple with the UsedAt field value if set, nil otherwise -// and a boolean to check if the value has been set. -// NOTE: If the value is an explicit nil, `nil, true` will be returned -func (o *IdentityCredentialsCode) GetUsedAtOk() (*time.Time, bool) { - if o == nil { - return nil, false - } - return o.UsedAt.Get(), o.UsedAt.IsSet() -} - -// HasUsedAt returns a boolean if a field has been set. -func (o *IdentityCredentialsCode) HasUsedAt() bool { - if o != nil && o.UsedAt.IsSet() { - return true - } - - return false -} - -// SetUsedAt gets a reference to the given NullableTime and assigns it to the UsedAt field. -func (o *IdentityCredentialsCode) SetUsedAt(v time.Time) { - o.UsedAt.Set(&v) -} - -// SetUsedAtNil sets the value for UsedAt to be an explicit nil -func (o *IdentityCredentialsCode) SetUsedAtNil() { - o.UsedAt.Set(nil) -} - -// UnsetUsedAt ensures that no value is present for UsedAt, not even an explicit nil -func (o *IdentityCredentialsCode) UnsetUsedAt() { - o.UsedAt.Unset() +// SetAddresses gets a reference to the given []IdentityCredentialsCodeAddress and assigns it to the Addresses field. +func (o *IdentityCredentialsCode) SetAddresses(v []IdentityCredentialsCodeAddress) { + o.Addresses = v } func (o IdentityCredentialsCode) MarshalJSON() ([]byte, error) { toSerialize := map[string]interface{}{} - if o.AddressType != nil { - toSerialize["address_type"] = o.AddressType - } - if o.UsedAt.IsSet() { - toSerialize["used_at"] = o.UsedAt.Get() + if o.Addresses != nil { + toSerialize["addresses"] = o.Addresses } return json.Marshal(toSerialize) } diff --git a/internal/client-go/model_identity_credentials_code_address.go b/internal/client-go/model_identity_credentials_code_address.go new file mode 100644 index 000000000000..c739045e79e0 --- /dev/null +++ b/internal/client-go/model_identity_credentials_code_address.go @@ -0,0 +1,151 @@ +/* + * Ory Identities API + * + * This is the API specification for Ory Identities with features such as registration, login, recovery, account verification, profile settings, password reset, identity management, session management, email and sms delivery, and more. + * + * API version: + * Contact: office@ory.sh + */ + +// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. + +package client + +import ( + "encoding/json" +) + +// IdentityCredentialsCodeAddress struct for IdentityCredentialsCodeAddress +type IdentityCredentialsCodeAddress struct { + // The address for this code + Address *string `json:"address,omitempty"` + Channel *string `json:"channel,omitempty"` +} + +// NewIdentityCredentialsCodeAddress instantiates a new IdentityCredentialsCodeAddress object +// This constructor will assign default values to properties that have it defined, +// and makes sure properties required by API are set, but the set of arguments +// will change when the set of required properties is changed +func NewIdentityCredentialsCodeAddress() *IdentityCredentialsCodeAddress { + this := IdentityCredentialsCodeAddress{} + return &this +} + +// NewIdentityCredentialsCodeAddressWithDefaults instantiates a new IdentityCredentialsCodeAddress object +// This constructor will only assign default values to properties that have it defined, +// but it doesn't guarantee that properties required by API are set +func NewIdentityCredentialsCodeAddressWithDefaults() *IdentityCredentialsCodeAddress { + this := IdentityCredentialsCodeAddress{} + return &this +} + +// GetAddress returns the Address field value if set, zero value otherwise. +func (o *IdentityCredentialsCodeAddress) GetAddress() string { + if o == nil || o.Address == nil { + var ret string + return ret + } + return *o.Address +} + +// GetAddressOk returns a tuple with the Address field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *IdentityCredentialsCodeAddress) GetAddressOk() (*string, bool) { + if o == nil || o.Address == nil { + return nil, false + } + return o.Address, true +} + +// HasAddress returns a boolean if a field has been set. +func (o *IdentityCredentialsCodeAddress) HasAddress() bool { + if o != nil && o.Address != nil { + return true + } + + return false +} + +// SetAddress gets a reference to the given string and assigns it to the Address field. +func (o *IdentityCredentialsCodeAddress) SetAddress(v string) { + o.Address = &v +} + +// GetChannel returns the Channel field value if set, zero value otherwise. +func (o *IdentityCredentialsCodeAddress) GetChannel() string { + if o == nil || o.Channel == nil { + var ret string + return ret + } + return *o.Channel +} + +// GetChannelOk returns a tuple with the Channel field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *IdentityCredentialsCodeAddress) GetChannelOk() (*string, bool) { + if o == nil || o.Channel == nil { + return nil, false + } + return o.Channel, true +} + +// HasChannel returns a boolean if a field has been set. +func (o *IdentityCredentialsCodeAddress) HasChannel() bool { + if o != nil && o.Channel != nil { + return true + } + + return false +} + +// SetChannel gets a reference to the given string and assigns it to the Channel field. +func (o *IdentityCredentialsCodeAddress) SetChannel(v string) { + o.Channel = &v +} + +func (o IdentityCredentialsCodeAddress) MarshalJSON() ([]byte, error) { + toSerialize := map[string]interface{}{} + if o.Address != nil { + toSerialize["address"] = o.Address + } + if o.Channel != nil { + toSerialize["channel"] = o.Channel + } + return json.Marshal(toSerialize) +} + +type NullableIdentityCredentialsCodeAddress struct { + value *IdentityCredentialsCodeAddress + isSet bool +} + +func (v NullableIdentityCredentialsCodeAddress) Get() *IdentityCredentialsCodeAddress { + return v.value +} + +func (v *NullableIdentityCredentialsCodeAddress) Set(val *IdentityCredentialsCodeAddress) { + v.value = val + v.isSet = true +} + +func (v NullableIdentityCredentialsCodeAddress) IsSet() bool { + return v.isSet +} + +func (v *NullableIdentityCredentialsCodeAddress) Unset() { + v.value = nil + v.isSet = false +} + +func NewNullableIdentityCredentialsCodeAddress(val *IdentityCredentialsCodeAddress) *NullableIdentityCredentialsCodeAddress { + return &NullableIdentityCredentialsCodeAddress{value: val, isSet: true} +} + +func (v NullableIdentityCredentialsCodeAddress) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *NullableIdentityCredentialsCodeAddress) UnmarshalJSON(src []byte) error { + v.isSet = true + return json.Unmarshal(src, &v.value) +} diff --git a/internal/client-go/model_identity_patch_response.go b/internal/client-go/model_identity_patch_response.go index 2ee305f7da81..f67224edad01 100644 --- a/internal/client-go/model_identity_patch_response.go +++ b/internal/client-go/model_identity_patch_response.go @@ -17,8 +17,9 @@ import ( // IdentityPatchResponse Response for a single identity patch type IdentityPatchResponse struct { - // The action for this specific patch create ActionCreate Create this identity. - Action *string `json:"action,omitempty"` + // The action for this specific patch create ActionCreate Create this identity. error ActionError Error indicates that the patch failed. + Action *string `json:"action,omitempty"` + Error interface{} `json:"error,omitempty"` // The identity ID payload of this patch Identity *string `json:"identity,omitempty"` // The ID of this patch response, if an ID was specified in the patch. @@ -74,6 +75,39 @@ func (o *IdentityPatchResponse) SetAction(v string) { o.Action = &v } +// GetError returns the Error field value if set, zero value otherwise (both if not set or set to explicit null). +func (o *IdentityPatchResponse) GetError() interface{} { + if o == nil { + var ret interface{} + return ret + } + return o.Error +} + +// GetErrorOk returns a tuple with the Error field value if set, nil otherwise +// and a boolean to check if the value has been set. +// NOTE: If the value is an explicit nil, `nil, true` will be returned +func (o *IdentityPatchResponse) GetErrorOk() (*interface{}, bool) { + if o == nil || o.Error == nil { + return nil, false + } + return &o.Error, true +} + +// HasError returns a boolean if a field has been set. +func (o *IdentityPatchResponse) HasError() bool { + if o != nil && o.Error != nil { + return true + } + + return false +} + +// SetError gets a reference to the given interface{} and assigns it to the Error field. +func (o *IdentityPatchResponse) SetError(v interface{}) { + o.Error = v +} + // GetIdentity returns the Identity field value if set, zero value otherwise. func (o *IdentityPatchResponse) GetIdentity() string { if o == nil || o.Identity == nil { @@ -143,6 +177,9 @@ func (o IdentityPatchResponse) MarshalJSON() ([]byte, error) { if o.Action != nil { toSerialize["action"] = o.Action } + if o.Error != nil { + toSerialize["error"] = o.Error + } if o.Identity != nil { toSerialize["identity"] = o.Identity } diff --git a/internal/client-go/model_update_login_flow_with_code_method.go b/internal/client-go/model_update_login_flow_with_code_method.go index 5833200a3ce9..06272618da90 100644 --- a/internal/client-go/model_update_login_flow_with_code_method.go +++ b/internal/client-go/model_update_login_flow_with_code_method.go @@ -17,6 +17,8 @@ import ( // UpdateLoginFlowWithCodeMethod Update Login flow using the code method type UpdateLoginFlowWithCodeMethod struct { + // Address is the address to send the code to, in case that there are multiple addresses. This field is only used in two-factor flows and is ineffective for passwordless flows. + Address *string `json:"address,omitempty"` // Code is the 6 digits code sent to the user Code *string `json:"code,omitempty"` // CSRFToken is the anti-CSRF token @@ -50,6 +52,38 @@ func NewUpdateLoginFlowWithCodeMethodWithDefaults() *UpdateLoginFlowWithCodeMeth return &this } +// GetAddress returns the Address field value if set, zero value otherwise. +func (o *UpdateLoginFlowWithCodeMethod) GetAddress() string { + if o == nil || o.Address == nil { + var ret string + return ret + } + return *o.Address +} + +// GetAddressOk returns a tuple with the Address field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *UpdateLoginFlowWithCodeMethod) GetAddressOk() (*string, bool) { + if o == nil || o.Address == nil { + return nil, false + } + return o.Address, true +} + +// HasAddress returns a boolean if a field has been set. +func (o *UpdateLoginFlowWithCodeMethod) HasAddress() bool { + if o != nil && o.Address != nil { + return true + } + + return false +} + +// SetAddress gets a reference to the given string and assigns it to the Address field. +func (o *UpdateLoginFlowWithCodeMethod) SetAddress(v string) { + o.Address = &v +} + // GetCode returns the Code field value if set, zero value otherwise. func (o *UpdateLoginFlowWithCodeMethod) GetCode() string { if o == nil || o.Code == nil { @@ -228,6 +262,9 @@ func (o *UpdateLoginFlowWithCodeMethod) SetTransientPayload(v map[string]interfa func (o UpdateLoginFlowWithCodeMethod) MarshalJSON() ([]byte, error) { toSerialize := map[string]interface{}{} + if o.Address != nil { + toSerialize["address"] = o.Address + } if o.Code != nil { toSerialize["code"] = o.Code } diff --git a/internal/httpclient/.openapi-generator/FILES b/internal/httpclient/.openapi-generator/FILES index 5eaa392f30a9..118cf9b06463 100644 --- a/internal/httpclient/.openapi-generator/FILES +++ b/internal/httpclient/.openapi-generator/FILES @@ -42,6 +42,7 @@ docs/Identity.md docs/IdentityAPI.md docs/IdentityCredentials.md docs/IdentityCredentialsCode.md +docs/IdentityCredentialsCodeAddress.md docs/IdentityCredentialsOidc.md docs/IdentityCredentialsOidcProvider.md docs/IdentityCredentialsPassword.md @@ -165,6 +166,7 @@ model_health_status.go model_identity.go model_identity_credentials.go model_identity_credentials_code.go +model_identity_credentials_code_address.go model_identity_credentials_oidc.go model_identity_credentials_oidc_provider.go model_identity_credentials_password.go diff --git a/internal/httpclient/README.md b/internal/httpclient/README.md index 6e2097d9a320..97593523117a 100644 --- a/internal/httpclient/README.md +++ b/internal/httpclient/README.md @@ -166,6 +166,7 @@ Class | Method | HTTP request | Description - [Identity](docs/Identity.md) - [IdentityCredentials](docs/IdentityCredentials.md) - [IdentityCredentialsCode](docs/IdentityCredentialsCode.md) + - [IdentityCredentialsCodeAddress](docs/IdentityCredentialsCodeAddress.md) - [IdentityCredentialsOidc](docs/IdentityCredentialsOidc.md) - [IdentityCredentialsOidcProvider](docs/IdentityCredentialsOidcProvider.md) - [IdentityCredentialsPassword](docs/IdentityCredentialsPassword.md) diff --git a/internal/httpclient/model_identity_credentials_code.go b/internal/httpclient/model_identity_credentials_code.go index 75857f31c272..53fefb6719eb 100644 --- a/internal/httpclient/model_identity_credentials_code.go +++ b/internal/httpclient/model_identity_credentials_code.go @@ -13,14 +13,11 @@ package client import ( "encoding/json" - "time" ) // IdentityCredentialsCode CredentialsCode represents a one time login/registration code type IdentityCredentialsCode struct { - // The type of the address for this code - AddressType *string `json:"address_type,omitempty"` - UsedAt NullableTime `json:"used_at,omitempty"` + Addresses []IdentityCredentialsCodeAddress `json:"addresses,omitempty"` } // NewIdentityCredentialsCode instantiates a new IdentityCredentialsCode object @@ -40,88 +37,42 @@ func NewIdentityCredentialsCodeWithDefaults() *IdentityCredentialsCode { return &this } -// GetAddressType returns the AddressType field value if set, zero value otherwise. -func (o *IdentityCredentialsCode) GetAddressType() string { - if o == nil || o.AddressType == nil { - var ret string +// GetAddresses returns the Addresses field value if set, zero value otherwise. +func (o *IdentityCredentialsCode) GetAddresses() []IdentityCredentialsCodeAddress { + if o == nil || o.Addresses == nil { + var ret []IdentityCredentialsCodeAddress return ret } - return *o.AddressType + return o.Addresses } -// GetAddressTypeOk returns a tuple with the AddressType field value if set, nil otherwise +// GetAddressesOk returns a tuple with the Addresses field value if set, nil otherwise // and a boolean to check if the value has been set. -func (o *IdentityCredentialsCode) GetAddressTypeOk() (*string, bool) { - if o == nil || o.AddressType == nil { +func (o *IdentityCredentialsCode) GetAddressesOk() ([]IdentityCredentialsCodeAddress, bool) { + if o == nil || o.Addresses == nil { return nil, false } - return o.AddressType, true + return o.Addresses, true } -// HasAddressType returns a boolean if a field has been set. -func (o *IdentityCredentialsCode) HasAddressType() bool { - if o != nil && o.AddressType != nil { +// HasAddresses returns a boolean if a field has been set. +func (o *IdentityCredentialsCode) HasAddresses() bool { + if o != nil && o.Addresses != nil { return true } return false } -// SetAddressType gets a reference to the given string and assigns it to the AddressType field. -func (o *IdentityCredentialsCode) SetAddressType(v string) { - o.AddressType = &v -} - -// GetUsedAt returns the UsedAt field value if set, zero value otherwise (both if not set or set to explicit null). -func (o *IdentityCredentialsCode) GetUsedAt() time.Time { - if o == nil || o.UsedAt.Get() == nil { - var ret time.Time - return ret - } - return *o.UsedAt.Get() -} - -// GetUsedAtOk returns a tuple with the UsedAt field value if set, nil otherwise -// and a boolean to check if the value has been set. -// NOTE: If the value is an explicit nil, `nil, true` will be returned -func (o *IdentityCredentialsCode) GetUsedAtOk() (*time.Time, bool) { - if o == nil { - return nil, false - } - return o.UsedAt.Get(), o.UsedAt.IsSet() -} - -// HasUsedAt returns a boolean if a field has been set. -func (o *IdentityCredentialsCode) HasUsedAt() bool { - if o != nil && o.UsedAt.IsSet() { - return true - } - - return false -} - -// SetUsedAt gets a reference to the given NullableTime and assigns it to the UsedAt field. -func (o *IdentityCredentialsCode) SetUsedAt(v time.Time) { - o.UsedAt.Set(&v) -} - -// SetUsedAtNil sets the value for UsedAt to be an explicit nil -func (o *IdentityCredentialsCode) SetUsedAtNil() { - o.UsedAt.Set(nil) -} - -// UnsetUsedAt ensures that no value is present for UsedAt, not even an explicit nil -func (o *IdentityCredentialsCode) UnsetUsedAt() { - o.UsedAt.Unset() +// SetAddresses gets a reference to the given []IdentityCredentialsCodeAddress and assigns it to the Addresses field. +func (o *IdentityCredentialsCode) SetAddresses(v []IdentityCredentialsCodeAddress) { + o.Addresses = v } func (o IdentityCredentialsCode) MarshalJSON() ([]byte, error) { toSerialize := map[string]interface{}{} - if o.AddressType != nil { - toSerialize["address_type"] = o.AddressType - } - if o.UsedAt.IsSet() { - toSerialize["used_at"] = o.UsedAt.Get() + if o.Addresses != nil { + toSerialize["addresses"] = o.Addresses } return json.Marshal(toSerialize) } diff --git a/internal/httpclient/model_identity_credentials_code_address.go b/internal/httpclient/model_identity_credentials_code_address.go new file mode 100644 index 000000000000..c739045e79e0 --- /dev/null +++ b/internal/httpclient/model_identity_credentials_code_address.go @@ -0,0 +1,151 @@ +/* + * Ory Identities API + * + * This is the API specification for Ory Identities with features such as registration, login, recovery, account verification, profile settings, password reset, identity management, session management, email and sms delivery, and more. + * + * API version: + * Contact: office@ory.sh + */ + +// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. + +package client + +import ( + "encoding/json" +) + +// IdentityCredentialsCodeAddress struct for IdentityCredentialsCodeAddress +type IdentityCredentialsCodeAddress struct { + // The address for this code + Address *string `json:"address,omitempty"` + Channel *string `json:"channel,omitempty"` +} + +// NewIdentityCredentialsCodeAddress instantiates a new IdentityCredentialsCodeAddress object +// This constructor will assign default values to properties that have it defined, +// and makes sure properties required by API are set, but the set of arguments +// will change when the set of required properties is changed +func NewIdentityCredentialsCodeAddress() *IdentityCredentialsCodeAddress { + this := IdentityCredentialsCodeAddress{} + return &this +} + +// NewIdentityCredentialsCodeAddressWithDefaults instantiates a new IdentityCredentialsCodeAddress object +// This constructor will only assign default values to properties that have it defined, +// but it doesn't guarantee that properties required by API are set +func NewIdentityCredentialsCodeAddressWithDefaults() *IdentityCredentialsCodeAddress { + this := IdentityCredentialsCodeAddress{} + return &this +} + +// GetAddress returns the Address field value if set, zero value otherwise. +func (o *IdentityCredentialsCodeAddress) GetAddress() string { + if o == nil || o.Address == nil { + var ret string + return ret + } + return *o.Address +} + +// GetAddressOk returns a tuple with the Address field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *IdentityCredentialsCodeAddress) GetAddressOk() (*string, bool) { + if o == nil || o.Address == nil { + return nil, false + } + return o.Address, true +} + +// HasAddress returns a boolean if a field has been set. +func (o *IdentityCredentialsCodeAddress) HasAddress() bool { + if o != nil && o.Address != nil { + return true + } + + return false +} + +// SetAddress gets a reference to the given string and assigns it to the Address field. +func (o *IdentityCredentialsCodeAddress) SetAddress(v string) { + o.Address = &v +} + +// GetChannel returns the Channel field value if set, zero value otherwise. +func (o *IdentityCredentialsCodeAddress) GetChannel() string { + if o == nil || o.Channel == nil { + var ret string + return ret + } + return *o.Channel +} + +// GetChannelOk returns a tuple with the Channel field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *IdentityCredentialsCodeAddress) GetChannelOk() (*string, bool) { + if o == nil || o.Channel == nil { + return nil, false + } + return o.Channel, true +} + +// HasChannel returns a boolean if a field has been set. +func (o *IdentityCredentialsCodeAddress) HasChannel() bool { + if o != nil && o.Channel != nil { + return true + } + + return false +} + +// SetChannel gets a reference to the given string and assigns it to the Channel field. +func (o *IdentityCredentialsCodeAddress) SetChannel(v string) { + o.Channel = &v +} + +func (o IdentityCredentialsCodeAddress) MarshalJSON() ([]byte, error) { + toSerialize := map[string]interface{}{} + if o.Address != nil { + toSerialize["address"] = o.Address + } + if o.Channel != nil { + toSerialize["channel"] = o.Channel + } + return json.Marshal(toSerialize) +} + +type NullableIdentityCredentialsCodeAddress struct { + value *IdentityCredentialsCodeAddress + isSet bool +} + +func (v NullableIdentityCredentialsCodeAddress) Get() *IdentityCredentialsCodeAddress { + return v.value +} + +func (v *NullableIdentityCredentialsCodeAddress) Set(val *IdentityCredentialsCodeAddress) { + v.value = val + v.isSet = true +} + +func (v NullableIdentityCredentialsCodeAddress) IsSet() bool { + return v.isSet +} + +func (v *NullableIdentityCredentialsCodeAddress) Unset() { + v.value = nil + v.isSet = false +} + +func NewNullableIdentityCredentialsCodeAddress(val *IdentityCredentialsCodeAddress) *NullableIdentityCredentialsCodeAddress { + return &NullableIdentityCredentialsCodeAddress{value: val, isSet: true} +} + +func (v NullableIdentityCredentialsCodeAddress) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *NullableIdentityCredentialsCodeAddress) UnmarshalJSON(src []byte) error { + v.isSet = true + return json.Unmarshal(src, &v.value) +} diff --git a/internal/httpclient/model_identity_patch_response.go b/internal/httpclient/model_identity_patch_response.go index 2ee305f7da81..f67224edad01 100644 --- a/internal/httpclient/model_identity_patch_response.go +++ b/internal/httpclient/model_identity_patch_response.go @@ -17,8 +17,9 @@ import ( // IdentityPatchResponse Response for a single identity patch type IdentityPatchResponse struct { - // The action for this specific patch create ActionCreate Create this identity. - Action *string `json:"action,omitempty"` + // The action for this specific patch create ActionCreate Create this identity. error ActionError Error indicates that the patch failed. + Action *string `json:"action,omitempty"` + Error interface{} `json:"error,omitempty"` // The identity ID payload of this patch Identity *string `json:"identity,omitempty"` // The ID of this patch response, if an ID was specified in the patch. @@ -74,6 +75,39 @@ func (o *IdentityPatchResponse) SetAction(v string) { o.Action = &v } +// GetError returns the Error field value if set, zero value otherwise (both if not set or set to explicit null). +func (o *IdentityPatchResponse) GetError() interface{} { + if o == nil { + var ret interface{} + return ret + } + return o.Error +} + +// GetErrorOk returns a tuple with the Error field value if set, nil otherwise +// and a boolean to check if the value has been set. +// NOTE: If the value is an explicit nil, `nil, true` will be returned +func (o *IdentityPatchResponse) GetErrorOk() (*interface{}, bool) { + if o == nil || o.Error == nil { + return nil, false + } + return &o.Error, true +} + +// HasError returns a boolean if a field has been set. +func (o *IdentityPatchResponse) HasError() bool { + if o != nil && o.Error != nil { + return true + } + + return false +} + +// SetError gets a reference to the given interface{} and assigns it to the Error field. +func (o *IdentityPatchResponse) SetError(v interface{}) { + o.Error = v +} + // GetIdentity returns the Identity field value if set, zero value otherwise. func (o *IdentityPatchResponse) GetIdentity() string { if o == nil || o.Identity == nil { @@ -143,6 +177,9 @@ func (o IdentityPatchResponse) MarshalJSON() ([]byte, error) { if o.Action != nil { toSerialize["action"] = o.Action } + if o.Error != nil { + toSerialize["error"] = o.Error + } if o.Identity != nil { toSerialize["identity"] = o.Identity } diff --git a/internal/httpclient/model_update_login_flow_with_code_method.go b/internal/httpclient/model_update_login_flow_with_code_method.go index 5833200a3ce9..06272618da90 100644 --- a/internal/httpclient/model_update_login_flow_with_code_method.go +++ b/internal/httpclient/model_update_login_flow_with_code_method.go @@ -17,6 +17,8 @@ import ( // UpdateLoginFlowWithCodeMethod Update Login flow using the code method type UpdateLoginFlowWithCodeMethod struct { + // Address is the address to send the code to, in case that there are multiple addresses. This field is only used in two-factor flows and is ineffective for passwordless flows. + Address *string `json:"address,omitempty"` // Code is the 6 digits code sent to the user Code *string `json:"code,omitempty"` // CSRFToken is the anti-CSRF token @@ -50,6 +52,38 @@ func NewUpdateLoginFlowWithCodeMethodWithDefaults() *UpdateLoginFlowWithCodeMeth return &this } +// GetAddress returns the Address field value if set, zero value otherwise. +func (o *UpdateLoginFlowWithCodeMethod) GetAddress() string { + if o == nil || o.Address == nil { + var ret string + return ret + } + return *o.Address +} + +// GetAddressOk returns a tuple with the Address field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *UpdateLoginFlowWithCodeMethod) GetAddressOk() (*string, bool) { + if o == nil || o.Address == nil { + return nil, false + } + return o.Address, true +} + +// HasAddress returns a boolean if a field has been set. +func (o *UpdateLoginFlowWithCodeMethod) HasAddress() bool { + if o != nil && o.Address != nil { + return true + } + + return false +} + +// SetAddress gets a reference to the given string and assigns it to the Address field. +func (o *UpdateLoginFlowWithCodeMethod) SetAddress(v string) { + o.Address = &v +} + // GetCode returns the Code field value if set, zero value otherwise. func (o *UpdateLoginFlowWithCodeMethod) GetCode() string { if o == nil || o.Code == nil { @@ -228,6 +262,9 @@ func (o *UpdateLoginFlowWithCodeMethod) SetTransientPayload(v map[string]interfa func (o UpdateLoginFlowWithCodeMethod) MarshalJSON() ([]byte, error) { toSerialize := map[string]interface{}{} + if o.Address != nil { + toSerialize["address"] = o.Address + } if o.Code != nil { toSerialize["code"] = o.Code } diff --git a/spec/api.json b/spec/api.json index 8a4dbefe714e..768461e86ec4 100644 --- a/spec/api.json +++ b/spec/api.json @@ -81,6 +81,9 @@ } }, "schemas": { + "CodeChannel": { + "type": "string" + }, "DefaultError": {}, "Duration": { "description": "A Duration represents the elapsed time between two instants\nas an int64 nanosecond count. The representation limits the\nlargest representable duration to approximately 290 years.", @@ -1071,12 +1074,23 @@ "identityCredentialsCode": { "description": "CredentialsCode represents a one time login/registration code", "properties": { - "address_type": { - "description": "The type of the address for this code", + "addresses": { + "items": { + "$ref": "#/components/schemas/identityCredentialsCodeAddress" + }, + "type": "array" + } + }, + "type": "object" + }, + "identityCredentialsCodeAddress": { + "properties": { + "address": { + "description": "The address for this code", "type": "string" }, - "used_at": { - "$ref": "#/components/schemas/NullTime" + "channel": { + "$ref": "#/components/schemas/CodeChannel" } }, "type": "object" @@ -1149,12 +1163,16 @@ "description": "Response for a single identity patch", "properties": { "action": { - "description": "The action for this specific patch\ncreate ActionCreate Create this identity.", + "description": "The action for this specific patch\ncreate ActionCreate Create this identity.\nerror ActionError Error indicates that the patch failed.", "enum": [ - "create" + "create", + "error" ], "type": "string", - "x-go-enum-desc": "create ActionCreate Create this identity." + "x-go-enum-desc": "create ActionCreate Create this identity.\nerror ActionError Error indicates that the patch failed." + }, + "error": { + "$ref": "#/components/schemas/DefaultError" }, "identity": { "description": "The identity ID payload of this patch", @@ -2723,6 +2741,10 @@ "updateLoginFlowWithCodeMethod": { "description": "Update Login flow using the code method", "properties": { + "address": { + "description": "Address is the address to send the code to, in case that there are multiple addresses. This field\nis only used in two-factor flows and is ineffective for passwordless flows.", + "type": "string" + }, "code": { "description": "Code is the 6 digits code sent to the user", "type": "string" @@ -5611,7 +5633,7 @@ } }, { - "description": "Via should contain the identity's credential the code should be sent to. Only relevant in aal2 flows.", + "description": "Via should contain the identity's credential the code should be sent to. Only relevant in aal2 flows.\n\nDEPRECATED: This field is deprecated. Please remove it from your requests. The user will now see a choice\nof MFA credentials to choose from to perform the second factor instead.", "in": "query", "name": "via", "schema": { @@ -5711,7 +5733,7 @@ } }, { - "description": "Via should contain the identity's credential the code should be sent to. Only relevant in aal2 flows.", + "description": "Via should contain the identity's credential the code should be sent to. Only relevant in aal2 flows.\n\nDEPRECATED: This field is deprecated. Please remove it from your requests. The user will now see a choice\nof MFA credentials to choose from to perform the second factor instead.", "in": "query", "name": "via", "schema": { diff --git a/spec/swagger.json b/spec/swagger.json index de605124627b..1e5ec02dd281 100755 --- a/spec/swagger.json +++ b/spec/swagger.json @@ -1604,7 +1604,7 @@ }, { "type": "string", - "description": "Via should contain the identity's credential the code should be sent to. Only relevant in aal2 flows.", + "description": "Via should contain the identity's credential the code should be sent to. Only relevant in aal2 flows.\n\nDEPRECATED: This field is deprecated. Please remove it from your requests. The user will now see a choice\nof MFA credentials to choose from to perform the second factor instead.", "name": "via", "in": "query" } @@ -1685,7 +1685,7 @@ }, { "type": "string", - "description": "Via should contain the identity's credential the code should be sent to. Only relevant in aal2 flows.", + "description": "Via should contain the identity's credential the code should be sent to. Only relevant in aal2 flows.\n\nDEPRECATED: This field is deprecated. Please remove it from your requests. The user will now see a choice\nof MFA credentials to choose from to perform the second factor instead.", "name": "via", "in": "query" } @@ -3245,6 +3245,9 @@ } }, "definitions": { + "CodeChannel": { + "type": "string" + }, "DefaultError": {}, "Duration": { "description": "A Duration represents the elapsed time between two instants\nas an int64 nanosecond count. The representation limits the\nlargest representable duration to approximately 290 years.", @@ -3259,20 +3262,6 @@ "type": "object", "title": "JSONRawMessage represents a json.RawMessage that works well with JSON, SQL, and Swagger." }, - "NullTime": { - "description": "NullTime implements the [Scanner] interface so\nit can be used as a scan destination, similar to [NullString].", - "type": "object", - "title": "NullTime represents a [time.Time] that may be null.", - "properties": { - "Time": { - "type": "string", - "format": "date-time" - }, - "Valid": { - "type": "boolean" - } - } - }, "NullUUID": { "description": "NullUUID can be used with the standard sql package to represent a\nUUID value that can be NULL in the database.", "type": "object", @@ -4197,12 +4186,23 @@ "description": "CredentialsCode represents a one time login/registration code", "type": "object", "properties": { - "address_type": { - "description": "The type of the address for this code", + "addresses": { + "type": "array", + "items": { + "$ref": "#/definitions/identityCredentialsCodeAddress" + } + } + } + }, + "identityCredentialsCodeAddress": { + "type": "object", + "properties": { + "address": { + "description": "The address for this code", "type": "string" }, - "used_at": { - "$ref": "#/definitions/NullTime" + "channel": { + "$ref": "#/definitions/CodeChannel" } } }, @@ -4275,12 +4275,16 @@ "type": "object", "properties": { "action": { - "description": "The action for this specific patch\ncreate ActionCreate Create this identity.", + "description": "The action for this specific patch\ncreate ActionCreate Create this identity.\nerror ActionError Error indicates that the patch failed.", "type": "string", "enum": [ - "create" + "create", + "error" ], - "x-go-enum-desc": "create ActionCreate Create this identity." + "x-go-enum-desc": "create ActionCreate Create this identity.\nerror ActionError Error indicates that the patch failed." + }, + "error": { + "$ref": "#/definitions/DefaultError" }, "identity": { "description": "The identity ID payload of this patch", @@ -5765,6 +5769,10 @@ "csrf_token" ], "properties": { + "address": { + "description": "Address is the address to send the code to, in case that there are multiple addresses. This field\nis only used in two-factor flows and is ineffective for passwordless flows.", + "type": "string" + }, "code": { "description": "Code is the 6 digits code sent to the user", "type": "string" From 2b7e3e966b7164489ef116742c0a9d03a4bf3ae7 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 10 Sep 2024 09:35:38 +0200 Subject: [PATCH 03/21] fix: implement Unwrap() on partial create error --- identity/manager.go | 7 +++++++ persistence/sql/identity/persister_identity.go | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/identity/manager.go b/identity/manager.go index 16f68f792e01..fd13bc6041f2 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -340,6 +340,13 @@ type CreateIdentitiesError struct { func (e *CreateIdentitiesError) Error() string { return fmt.Sprintf("create identities error: %d identities failed", len(e.Failed)) } +func (e *CreateIdentitiesError) Unwrap() []error { + var errs []error + for _, failed := range e.Failed { + errs = append(errs, failed.Error) + } + return errs +} func (e *CreateIdentitiesError) Contains(ident *Identity) bool { for _, failed := range e.Failed { if failed.Identity.ID == ident.ID { diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 1616880091e3..bc0fc3bff3d3 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -631,7 +631,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... if _, ok := failedIdentityIDs[ident.ID]; ok { partialErr.Failed = append(partialErr.Failed, &identity.FailedIdentity{ Identity: ident, - Error: herodot.ErrConflict.WithReason("This identity conflicts with another identity that already exists."), + Error: sqlcon.ErrUniqueViolation, }) failedIDs = append(failedIDs, ident.ID) } From e2b7831a1ff7dedf36e05abd735e5aaa25178a48 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 10 Sep 2024 09:46:17 +0200 Subject: [PATCH 04/21] bump upload action --- .github/workflows/ci.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9c2ee0555b42..545f208563e4 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -210,7 +210,7 @@ jobs: REACT_UI_PATH: react-ui CYPRESS_RECORD_KEY: ${{ secrets.CYPRESS_RECORD_KEY }} - if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: logs path: test/e2e/*.e2e.log @@ -320,12 +320,12 @@ jobs: NODE_UI_PATH: node-ui REACT_UI_PATH: react-ui - if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: logs path: test/e2e/*.e2e.log - if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: playwright-test-results-${{ github.sha }} path: | From 01f199de5e3f3b13d57e0a709b4be8a691d74f8e Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 10 Sep 2024 09:51:56 +0200 Subject: [PATCH 05/21] fix tests --- .../Test_buildInsertQueryArgs-case=cockroach.json | 2 +- ...QueryValues-case=testModel-case=cockroach.json | 10 ---------- persistence/sql/batch/create_test.go | 15 ++++++++++++++- 3 files changed, 15 insertions(+), 12 deletions(-) delete mode 100644 persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json diff --git a/persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json b/persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json index 4fc722f33afd..ef25e1645b65 100644 --- a/persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json +++ b/persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json @@ -11,5 +11,5 @@ "traits", "updated_at" ], - "Placeholders": "(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?)" + "Placeholders": "(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?)" } diff --git a/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json b/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json deleted file mode 100644 index 9c8e755cacd5..000000000000 --- a/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json +++ /dev/null @@ -1,10 +0,0 @@ -[ - "0001-01-01T00:00:00Z", - "0001-01-01T00:00:00Z", - "string", - 42, - null, - { - "foo": "bar" - } -] diff --git a/persistence/sql/batch/create_test.go b/persistence/sql/batch/create_test.go index f5c81664a486..31c736cdc847 100644 --- a/persistence/sql/batch/create_test.go +++ b/persistence/sql/batch/create_test.go @@ -120,7 +120,20 @@ func Test_buildInsertQueryValues(t *testing.T) { t.Run("case=cockroach", func(t *testing.T) { values, err := buildInsertQueryValues(dbal.DriverCockroachDB, mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc) require.NoError(t, err) - snapshotx.SnapshotT(t, values) + + assert.NotNil(t, model.CreatedAt) + assert.Equal(t, model.CreatedAt, values[0]) + + assert.NotNil(t, model.UpdatedAt) + assert.Equal(t, model.UpdatedAt, values[1]) + + assert.NotZero(t, model.ID) + assert.Equal(t, model.ID, values[2]) + + assert.Equal(t, model.String, values[3]) + assert.Equal(t, model.Int, values[4]) + + assert.Nil(t, model.NullTimePtr) }) t.Run("case=others", func(t *testing.T) { From 447c5d6564baa8ea1648ed2a77ef264d332eb6f1 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 10 Sep 2024 10:14:37 +0200 Subject: [PATCH 06/21] fix: mysql --- persistence/sql/batch/create.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/persistence/sql/batch/create.go b/persistence/sql/batch/create.go index 67b9b49cc1f0..6c423532902d 100644 --- a/persistence/sql/batch/create.go +++ b/persistence/sql/batch/create.go @@ -186,6 +186,12 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e if err != nil { return sqlcon.HandleError(err) } + // MySQL, which does not support RETURNING, also does not have ON CONFLICT DO + // NOTHING, meaning that MySQL will always fail the whole transaction on a single + // record conflict. + if conn.Dialect.Name() == dbal.DriverMySQL { + return sqlcon.HandleError(rows.Close()) + } idIdx := slices.Index(queryArgs.Columns, "id") if idIdx == -1 { @@ -196,9 +202,7 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e idValues = append(idValues, values[i].(uuid.UUID)) } - // Hydrate the models from the RETURNING clause. Note that MySQL, which does not - // support RETURNING, also does not have ON CONFLICT DO NOTHING, meaning that - // MySQL will always fail the whole transaction on a single record conflict. + // Hydrate the models from the RETURNING clause. idsInDB := make(map[uuid.UUID]struct{}) for rows.Next() { if err := rows.Err(); err != nil { From cec29ca6f023ccdd84299005e9d7697c4b66ed8b Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 10 Sep 2024 12:33:50 +0200 Subject: [PATCH 07/21] fix tests --- internal/client-go/go.sum | 1 + persistence/sql/batch/create.go | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/persistence/sql/batch/create.go b/persistence/sql/batch/create.go index 6c423532902d..e6f69aef5c8d 100644 --- a/persistence/sql/batch/create.go +++ b/persistence/sql/batch/create.go @@ -233,7 +233,11 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e } } - return partialConflictError.ErrOrNil() + if len(partialConflictError.Failed) > 0 { + return sqlcon.ErrUniqueViolation.WithWrap(&partialConflictError) + } + + return nil } // setModelID was copy & pasted from pop. It basically sets From badcd6ea47dcc39c256b4264bb33e59157458da6 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 10 Sep 2024 14:35:44 +0200 Subject: [PATCH 08/21] fix: uploading --- .github/workflows/ci.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 545f208563e4..17e6e8e7efb1 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -212,7 +212,7 @@ jobs: - if: failure() uses: actions/upload-artifact@v4 with: - name: logs + name: cypress-logs path: test/e2e/*.e2e.log test-e2e-playwright: @@ -322,12 +322,12 @@ jobs: - if: failure() uses: actions/upload-artifact@v4 with: - name: logs + name: cypress-${{ matrix.database }}-logs path: test/e2e/*.e2e.log - if: failure() uses: actions/upload-artifact@v4 with: - name: playwright-test-results-${{ github.sha }} + name: playwright-test-results-${{ matrix.database }}-${{ github.sha }} path: | test/e2e/test-results/ test/e2e/playwright-report/ From 6fec60d32362b675dd82e3ef602a5f0fdb2b90f2 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 10 Sep 2024 14:36:01 +0200 Subject: [PATCH 09/21] fix: wrapping --- persistence/sql/identity/persister_identity.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index bc0fc3bff3d3..b0ecf85e8a43 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -636,10 +636,14 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... failedIDs = append(failedIDs, ident.ID) } } + // Manually roll back by deleting the identities that were inserted before the + // error occurred. if err := p.DeleteIdentities(ctx, failedIDs); err != nil { return sqlcon.HandleError(err) } - return partialErr + // Wrap the partial error with the first error that occurred, so that the caller + // can continue to handle the error either as a partial error or a full error. + return partialErr.Failed[0].Error.WithWrap(partialErr) } return nil From 04b302c4d578627040f6061b4e31d2ee2002b5bb Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 10 Sep 2024 15:38:49 +0200 Subject: [PATCH 10/21] fix artifact upload names --- .github/workflows/ci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 17e6e8e7efb1..f36f008c5e32 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -212,7 +212,7 @@ jobs: - if: failure() uses: actions/upload-artifact@v4 with: - name: cypress-logs + name: cypress-${{ matrix.database }}-logs path: test/e2e/*.e2e.log test-e2e-playwright: @@ -322,7 +322,7 @@ jobs: - if: failure() uses: actions/upload-artifact@v4 with: - name: cypress-${{ matrix.database }}-logs + name: playwright-${{ matrix.database }}-logs path: test/e2e/*.e2e.log - if: failure() uses: actions/upload-artifact@v4 From 0ed2ed133107727e0730fcf121900dcf411635d8 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Wed, 11 Sep 2024 11:12:19 +0200 Subject: [PATCH 11/21] test: fix lookup E2E test --- .../cypress/integration/profiles/mfa/lookup.spec.ts | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/test/e2e/cypress/integration/profiles/mfa/lookup.spec.ts b/test/e2e/cypress/integration/profiles/mfa/lookup.spec.ts index 047496a5ef37..1bb4e2f94898 100644 --- a/test/e2e/cypress/integration/profiles/mfa/lookup.spec.ts +++ b/test/e2e/cypress/integration/profiles/mfa/lookup.spec.ts @@ -192,10 +192,9 @@ context("2FA lookup secrets", () => { cy.visit(settings) cy.get('button[name="lookup_secret_reveal"]').click() cy.getLookupSecrets().should((c) => { - let newCodes = codes - newCodes[0] = "Used" - newCodes[1] = "Used" - expect(c).to.eql(newCodes) + expect(c.slice(2)).to.eql(codes.slice(2)) + expect(c[0]).to.match(/(Secret was used at )|(Used)/g) + expect(c[1]).to.match(/(Secret was used at )|(Used)/g) }) // Regenerating the codes means the old one become invalid @@ -235,9 +234,8 @@ context("2FA lookup secrets", () => { cy.visit(settings) cy.get('button[name="lookup_secret_reveal"]').click() cy.getLookupSecrets().should((c) => { - let newCodes = regenCodes - newCodes[0] = "Used" - expect(c).to.eql(newCodes) + expect(c.slice(1)).to.eql(regenCodes.slice(1)) + expect(c[0]).to.match(/(Secret was used at )|(Used)/g) }) }) From 03c127c7427f417f23a3f7309d83440fa2df89e8 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Wed, 11 Sep 2024 12:42:55 +0200 Subject: [PATCH 12/21] test: always use www.example.org as return_to URL in tests --- .../profiles/email/logout/success.spec.ts | 12 ++++++------ .../profiles/oidc/registration/success.spec.ts | 4 ++-- .../profiles/recovery/code/success.spec.ts | 4 ++-- .../profiles/recovery/link/success.spec.ts | 7 +++++-- .../profiles/recovery/return-to/success.spec.ts | 8 ++++---- .../profiles/two-steps/registration/oidc.spec.ts | 4 ++-- 6 files changed, 21 insertions(+), 18 deletions(-) diff --git a/test/e2e/cypress/integration/profiles/email/logout/success.spec.ts b/test/e2e/cypress/integration/profiles/email/logout/success.spec.ts index f13358166afb..3926f21012b1 100644 --- a/test/e2e/cypress/integration/profiles/email/logout/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/email/logout/success.spec.ts @@ -75,12 +75,12 @@ context("Testing logout flows", () => { cy.visit(settings, { qs: { - return_to: "https://www.ory.sh", + return_to: "https://www.example.org", }, }) cy.get("a[href*='logout']").click() - cy.location("host").should("eq", "www.ory.sh") + cy.location("host").should("eq", "www.example.org") }) it("should be able to sign out on welcome page", () => { @@ -94,12 +94,12 @@ context("Testing logout flows", () => { cy.visit(welcome, { qs: { - return_to: "https://www.ory.sh", + return_to: "https://www.example.org", }, }) cy.get("a[href*='logout']").click() - cy.location("host").should("eq", "www.ory.sh") + cy.location("host").should("eq", "www.example.org") }) it("should be able to sign out at 2fa page", () => { @@ -122,7 +122,7 @@ context("Testing logout flows", () => { cy.logout() cy.visit(route, { qs: { - return_to: "https://www.ory.sh", + return_to: "https://www.example.org", }, }) @@ -135,7 +135,7 @@ context("Testing logout flows", () => { cy.get("a[href*='logout']").click() - cy.location("host").should("eq", "www.ory.sh") + cy.location("host").should("eq", "www.example.org") cy.useLookupSecrets(false) }) }) diff --git a/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts b/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts index 54adac4bdd16..132845623f31 100644 --- a/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts @@ -194,9 +194,9 @@ context("Social Sign Up Successes", () => { app, email, website, - route: registration + "?return_to=https://www.ory.sh/", + route: registration + "?return_to=https://www.example.org/", }) - cy.location("href").should("eq", "https://www.ory.sh/") + cy.location("href").should("eq", "https://www.example.org/") cy.logout() }) diff --git a/test/e2e/cypress/integration/profiles/recovery/code/success.spec.ts b/test/e2e/cypress/integration/profiles/recovery/code/success.spec.ts index a27bca61d967..baafe5a389c7 100644 --- a/test/e2e/cypress/integration/profiles/recovery/code/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/recovery/code/success.spec.ts @@ -180,7 +180,7 @@ context("Account Recovery With Code Success", () => { const identity = gen.identityWithWebsite() cy.registerApi(identity) - cy.visit(express.recovery + "?return_to=https://www.ory.sh/") + cy.visit(express.recovery + "?return_to=https://www.example.org/") cy.get("input[name='email']").type(identity.email) cy.get("button[value='code']").click() cy.get('[data-testid="ui/message/1060003"]').should( @@ -196,6 +196,6 @@ context("Account Recovery With Code Success", () => { cy.get('input[name="password"]').clear().type(gen.password()) cy.get('button[value="password"]').click() - cy.url().should("eq", "https://www.ory.sh/") + cy.url().should("eq", "https://www.example.org/") }) }) diff --git a/test/e2e/cypress/integration/profiles/recovery/link/success.spec.ts b/test/e2e/cypress/integration/profiles/recovery/link/success.spec.ts index fc4137200b59..abfa089375b8 100644 --- a/test/e2e/cypress/integration/profiles/recovery/link/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/recovery/link/success.spec.ts @@ -109,7 +109,10 @@ context("Account Recovery Success", () => { const identity = gen.identityWithWebsite() cy.registerApi(identity) - cy.recoverApi({ email: identity.email, returnTo: "https://www.ory.sh/" }) + cy.recoverApi({ + email: identity.email, + returnTo: "https://www.example.org/", + }) cy.recoverEmail({ expect: identity }) @@ -120,7 +123,7 @@ context("Account Recovery Success", () => { .clear() .type(gen.password()) cy.get('button[value="password"]').click() - cy.url().should("eq", "https://www.ory.sh/") + cy.url().should("eq", "https://www.example.org/") }) it("should recover even if already logged into another account", () => { diff --git a/test/e2e/cypress/integration/profiles/recovery/return-to/success.spec.ts b/test/e2e/cypress/integration/profiles/recovery/return-to/success.spec.ts index 0fa3f12a9524..af5ffa0acd1c 100644 --- a/test/e2e/cypress/integration/profiles/recovery/return-to/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/recovery/return-to/success.spec.ts @@ -63,7 +63,7 @@ context("Recovery with `return_to`", () => { } it("should return to the `return_to` url after successful account recovery and settings update", () => { - cy.visit(recovery + "?return_to=https://www.ory.sh/") + cy.visit(recovery + "?return_to=https://www.example.org/") doRecovery() cy.get('[data-testid="ui/message/1060001"]', { timeout: 30000 }).should( @@ -80,7 +80,7 @@ context("Recovery with `return_to`", () => { .type(newPassword) cy.get('button[value="password"]').click() - cy.location("hostname").should("eq", "www.ory.sh") + cy.location("hostname").should("eq", "www.example.org") }) it("should return to the `return_to` url even with mfa enabled after successful account recovery and settings update", () => { @@ -108,7 +108,7 @@ context("Recovery with `return_to`", () => { cy.logout() cy.clearAllCookies() - cy.visit(recovery + "?return_to=https://www.ory.sh/") + cy.visit(recovery + "?return_to=https://www.example.org/") doRecovery() cy.shouldShow2FAScreen() @@ -122,7 +122,7 @@ context("Recovery with `return_to`", () => { .clear() .type(newPassword) cy.get('button[value="password"]').click() - cy.location("hostname").should("eq", "www.ory.sh") + cy.location("hostname").should("eq", "www.example.org") }) }) }) diff --git a/test/e2e/cypress/integration/profiles/two-steps/registration/oidc.spec.ts b/test/e2e/cypress/integration/profiles/two-steps/registration/oidc.spec.ts index cd4cf069c556..ca2d5e57078a 100644 --- a/test/e2e/cypress/integration/profiles/two-steps/registration/oidc.spec.ts +++ b/test/e2e/cypress/integration/profiles/two-steps/registration/oidc.spec.ts @@ -194,9 +194,9 @@ context("Social Sign Up Successes", () => { app, email, website, - route: registration + "?return_to=https://www.ory.sh/", + route: registration + "?return_to=https://www.example.org/", }) - cy.location("href").should("eq", "https://www.ory.sh/") + cy.location("href").should("eq", "https://www.example.org/") cy.logout() }) From a2842a01c035423651d1a8a979a2edf4e9184cdf Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Thu, 12 Sep 2024 09:23:20 +0200 Subject: [PATCH 13/21] code review --- identity/handler_test.go | 99 +++++++++---------- identity/manager.go | 52 ++++++---- .../sql/identity/persister_identity.go | 7 +- 3 files changed, 81 insertions(+), 77 deletions(-) diff --git a/identity/handler_test.go b/identity/handler_test.go index 2162095c0431..33d6a46adf87 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -759,68 +759,63 @@ func TestHandler(t *testing.T) { assert.Contains(t, res.Get("error.reason").String(), strconv.Itoa(identity.BatchPatchIdentitiesLimit), "the error reason should contain the limit") }) - t.Run("case=fails all on a bad identity", func(t *testing.T) { + t.Run("case=fails some on a bad identity", func(t *testing.T) { // Test setup: we have a list of valid identitiy patches and a list of invalid ones. // Each run adds one invalid patch to the list and sends it to the server. // --> we expect the server to fail all patches in the list. // Finally, we send just the valid patches // --> we expect the server to succeed all patches in the list. - validPatches := []*identity.BatchIdentityPatch{ - {Create: validCreateIdentityBody("valid-patch", 0)}, - {Create: validCreateIdentityBody("valid-patch", 1)}, - {Create: validCreateIdentityBody("valid-patch", 2)}, - {Create: validCreateIdentityBody("valid-patch", 3)}, - {Create: validCreateIdentityBody("valid-patch", 4)}, - } - for _, tt := range []struct { - name string - body *identity.CreateIdentityBody - }{ - { - name: "missing-all-fields", - body: &identity.CreateIdentityBody{}, - }, - { - name: "duplicate-identity", - body: validCreateIdentityBody("duplicate-identity", 0), - }, - { - name: "invalid-traits", - body: &identity.CreateIdentityBody{ - Traits: json.RawMessage(`"invalid traits"`), - }, - }, - } { - t.Run("invalid because "+tt.name, func(t *testing.T) { - validPatches := []*identity.BatchIdentityPatch{ - {Create: validCreateIdentityBody(tt.name, 0)}, - {Create: validCreateIdentityBody(tt.name, 1)}, - {Create: validCreateIdentityBody(tt.name, 2)}, - {Create: validCreateIdentityBody(tt.name, 3)}, - {Create: validCreateIdentityBody(tt.name, 4)}, - } + t.Run("case=invalid patches fail", func(t *testing.T) { + patches := []*identity.BatchIdentityPatch{ + {Create: validCreateIdentityBody("valid", 0)}, + {Create: validCreateIdentityBody("valid", 1)}, + {Create: &identity.CreateIdentityBody{}}, // <-- invalid: missing all fields + {Create: validCreateIdentityBody("valid", 2)}, + {Create: validCreateIdentityBody("valid", 0)}, // <-- duplicate + {Create: validCreateIdentityBody("valid", 3)}, + {Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits + {Create: validCreateIdentityBody("valid", 4)}, + } - patches := make([]*identity.BatchIdentityPatch, 0, len(validPatches)+1) - patches = append(patches, validPatches[0:3]...) - patches = append(patches, &identity.BatchIdentityPatch{Create: tt.body}) - patches = append(patches, validPatches[3:5]...) - for i, p := range patches { - id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%s-%d", tt.name, i)) - p.ID = &id - } + // Create unique IDs for each patch + var patchIDs []string + for i, p := range patches { + id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i)) + p.ID = &id + patchIDs = append(patchIDs, id.String()) + } - req := &identity.BatchPatchIdentitiesBody{Identities: patches} - body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) - var actions []string - for _, a := range body.Get("identities.#.action").Array() { - actions = append(actions, a.String()) - } - assert.Equal(t, []string{"create", "create", "create", "error", "create", "create"}, actions, body) - }) - } + req := &identity.BatchPatchIdentitiesBody{Identities: patches} + body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) + var actions []string + for _, a := range body.Get("identities.#.action").Array() { + actions = append(actions, a.String()) + } + assert.Equal(t, + []string{"create", "create", "error", "create", "error", "create", "error", "create"}, + actions, body) + + // Check that all patch IDs are returned + for i, gotPatchID := range body.Get("identities.#.patch_id").Array() { + assert.Equal(t, patchIDs[i], gotPatchID.String()) + } + + // Check specific errors + assert.Equal(t, "Bad Request", body.Get("identities.2.error.status").String()) + assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String()) + assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String()) + + }) t.Run("valid patches succeed", func(t *testing.T) { + validPatches := []*identity.BatchIdentityPatch{ + {Create: validCreateIdentityBody("valid-patch", 0)}, + {Create: validCreateIdentityBody("valid-patch", 1)}, + {Create: validCreateIdentityBody("valid-patch", 2)}, + {Create: validCreateIdentityBody("valid-patch", 3)}, + {Create: validCreateIdentityBody("valid-patch", 4)}, + } req := &identity.BatchPatchIdentitiesBody{Identities: validPatches} send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) }) diff --git a/identity/manager.go b/identity/manager.go index fd13bc6041f2..485979ef773f 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -334,41 +334,56 @@ type FailedIdentity struct { } type CreateIdentitiesError struct { - Failed []*FailedIdentity + failedIdentities map[*Identity]*herodot.DefaultError } func (e *CreateIdentitiesError) Error() string { - return fmt.Sprintf("create identities error: %d identities failed", len(e.Failed)) + e.init() + return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities)) } func (e *CreateIdentitiesError) Unwrap() []error { + e.init() var errs []error - for _, failed := range e.Failed { - errs = append(errs, failed.Error) + for _, err := range e.failedIdentities { + errs = append(errs, err) } return errs } -func (e *CreateIdentitiesError) Contains(ident *Identity) bool { - for _, failed := range e.Failed { - if failed.Identity.ID == ident.ID { - return true - } + +func (e *CreateIdentitiesError) AddFailedIdentity(ident *Identity, err *herodot.DefaultError) { + e.init() + e.failedIdentities[ident] = err +} +func (e *CreateIdentitiesError) Merge(other *CreateIdentitiesError) { + e.init() + for k, v := range other.failedIdentities { + e.failedIdentities[k] = v } - return false +} +func (e *CreateIdentitiesError) Contains(ident *Identity) bool { + e.init() + _, found := e.failedIdentities[ident] + return found } func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity { - for _, failed := range e.Failed { - if failed.Identity.ID == ident.ID { - return failed - } + e.init() + if err, found := e.failedIdentities[ident]; found { + return &FailedIdentity{Identity: ident, Error: err} } + return nil } func (e *CreateIdentitiesError) ErrOrNil() error { - if len(e.Failed) == 0 { + if e.failedIdentities == nil || len(e.failedIdentities) == 0 { return nil } return e } +func (e *CreateIdentitiesError) init() { + if e.failedIdentities == nil { + e.failedIdentities = map[*Identity]*herodot.DefaultError{} + } +} func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, opts ...ManagerOption) (err error) { ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities") @@ -383,10 +398,7 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, o := newManagerOptions(opts) if err := m.ValidateIdentity(ctx, ident, o); err != nil { - createIdentitiesError.Failed = append(createIdentitiesError.Failed, &FailedIdentity{ - Identity: ident, - Error: herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err), - }) + createIdentitiesError.AddFailedIdentity(ident, herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err)) continue } validIdentities = append(validIdentities, ident) @@ -394,7 +406,7 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, if err := m.r.PrivilegedIdentityPool().CreateIdentities(ctx, validIdentities...); err != nil { if partialErr := new(CreateIdentitiesError); errors.As(err, &partialErr) { - createIdentitiesError.Failed = append(createIdentitiesError.Failed, partialErr.Failed...) + createIdentitiesError.Merge(partialErr) } else { return err } diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index b0ecf85e8a43..64643bde529f 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -629,10 +629,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs)) for _, ident := range identities { if _, ok := failedIdentityIDs[ident.ID]; ok { - partialErr.Failed = append(partialErr.Failed, &identity.FailedIdentity{ - Identity: ident, - Error: sqlcon.ErrUniqueViolation, - }) + partialErr.AddFailedIdentity(ident, sqlcon.ErrUniqueViolation) failedIDs = append(failedIDs, ident.ID) } } @@ -643,7 +640,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... } // Wrap the partial error with the first error that occurred, so that the caller // can continue to handle the error either as a partial error or a full error. - return partialErr.Failed[0].Error.WithWrap(partialErr) + return partialErr } return nil From 50b21ed149d9619f312358dce201ec392352431f Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Thu, 12 Sep 2024 09:25:54 +0200 Subject: [PATCH 14/21] code review --- persistence/sql/batch/create_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/persistence/sql/batch/create_test.go b/persistence/sql/batch/create_test.go index 31c736cdc847..60ef6af4734a 100644 --- a/persistence/sql/batch/create_test.go +++ b/persistence/sql/batch/create_test.go @@ -121,10 +121,10 @@ func Test_buildInsertQueryValues(t *testing.T) { values, err := buildInsertQueryValues(dbal.DriverCockroachDB, mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc) require.NoError(t, err) - assert.NotNil(t, model.CreatedAt) + assert.NotZero(t, model.CreatedAt) assert.Equal(t, model.CreatedAt, values[0]) - assert.NotNil(t, model.UpdatedAt) + assert.NotZero(t, model.UpdatedAt) assert.Equal(t, model.UpdatedAt, values[1]) assert.NotZero(t, model.ID) @@ -140,10 +140,10 @@ func Test_buildInsertQueryValues(t *testing.T) { values, err := buildInsertQueryValues("other", mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc) require.NoError(t, err) - assert.NotNil(t, model.CreatedAt) + assert.NotZero(t, model.CreatedAt) assert.Equal(t, model.CreatedAt, values[0]) - assert.NotNil(t, model.UpdatedAt) + assert.NotZero(t, model.UpdatedAt) assert.Equal(t, model.UpdatedAt, values[1]) assert.NotZero(t, model.ID) From 17ed6fe2a228b228b135f60c7b2f0dbd0c89d875 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Thu, 12 Sep 2024 09:32:24 +0200 Subject: [PATCH 15/21] fix test --- persistence/sql/batch/create_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/persistence/sql/batch/create_test.go b/persistence/sql/batch/create_test.go index 60ef6af4734a..2449cc6f41a4 100644 --- a/persistence/sql/batch/create_test.go +++ b/persistence/sql/batch/create_test.go @@ -114,17 +114,17 @@ func Test_buildInsertQueryValues(t *testing.T) { } mapper := reflectx.NewMapper("db") - nowFunc := func() time.Time { - return time.Time{} - } + frozenTime := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + nowFunc := func() time.Time { return frozenTime } + t.Run("case=cockroach", func(t *testing.T) { values, err := buildInsertQueryValues(dbal.DriverCockroachDB, mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc) require.NoError(t, err) - assert.NotZero(t, model.CreatedAt) + assert.Equal(t, frozenTime, model.CreatedAt) assert.Equal(t, model.CreatedAt, values[0]) - assert.NotZero(t, model.UpdatedAt) + assert.Equal(t, frozenTime, model.UpdatedAt) assert.Equal(t, model.UpdatedAt, values[1]) assert.NotZero(t, model.ID) @@ -140,10 +140,10 @@ func Test_buildInsertQueryValues(t *testing.T) { values, err := buildInsertQueryValues("other", mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc) require.NoError(t, err) - assert.NotZero(t, model.CreatedAt) + assert.Equal(t, frozenTime, model.CreatedAt) assert.Equal(t, model.CreatedAt, values[0]) - assert.NotZero(t, model.UpdatedAt) + assert.Equal(t, frozenTime, model.UpdatedAt) assert.Equal(t, model.UpdatedAt, values[1]) assert.NotZero(t, model.ID) From e4913293737713fe49060d0f24233be2565222c0 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Fri, 13 Sep 2024 12:58:48 +0200 Subject: [PATCH 16/21] add batch persister test --- persistence/sql/batch/test_persister.go | 90 +++++++++++++++++++++++++ persistence/sql/persister_test.go | 5 ++ 2 files changed, 95 insertions(+) create mode 100644 persistence/sql/batch/test_persister.go diff --git a/persistence/sql/batch/test_persister.go b/persistence/sql/batch/test_persister.go new file mode 100644 index 000000000000..7d476ca60150 --- /dev/null +++ b/persistence/sql/batch/test_persister.go @@ -0,0 +1,90 @@ +package batch + +import ( + "context" + "testing" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/identity" + "github.com/ory/kratos/persistence" + "github.com/ory/x/dbal" + "github.com/ory/x/otelx" + "github.com/ory/x/sqlcon" +) + +func TestPersister(ctx context.Context, tracer *otelx.Tracer, p persistence.Persister) func(t *testing.T) { + return func(t *testing.T) { + t.Run("method=batch.Create", func(t *testing.T) { + + ident1 := identity.NewIdentity("") + ident1.NID = p.NetworkID(ctx) + ident2 := identity.NewIdentity("") + ident2.NID = p.NetworkID(ctx) + + // Create two identities + _ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + conn := &TracerConnection{ + Tracer: tracer, + Connection: tx, + } + + err := Create(ctx, conn, []*identity.Identity{ident1, ident2}) + require.NoError(t, err) + + return nil + }) + + require.NotEqual(t, uuid.Nil, ident1.ID) + require.NotEqual(t, uuid.Nil, ident2.ID) + + // Create conflicting verifiable addresses + _ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + conn := &TracerConnection{ + Tracer: tracer, + Connection: tx, + } + + err := Create(ctx, conn, []*identity.VerifiableAddress{{ + Value: "foo.1@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.2@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "conflict@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.3@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "conflict@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.4@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }}) + + assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation) + + if conn.Connection.Dialect.Name() != dbal.DriverMySQL { + // MySQL does not support partial errors. + partialErr := new(PartialConflictError[identity.VerifiableAddress]) + require.ErrorAs(t, err, &partialErr) + assert.Len(t, partialErr.Failed, 1) + } + + return nil + }) + }) + } +} diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index 57593577c8c7..3029cdc51ef0 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -29,6 +29,7 @@ import ( "github.com/ory/kratos/internal" "github.com/ory/kratos/internal/testhelpers" "github.com/ory/kratos/persistence/sql" + "github.com/ory/kratos/persistence/sql/batch" sqltesthelpers "github.com/ory/kratos/persistence/sql/testhelpers" "github.com/ory/kratos/schema" errorx "github.com/ory/kratos/selfservice/errorx/test" @@ -264,6 +265,10 @@ func TestPersister(t *testing.T) { t.Parallel() continuity.TestPersister(ctx, p)(t) }) + t.Run("contract=batch.TestPersister", func(t *testing.T) { + t.Parallel() + batch.TestPersister(ctx, reg.Tracer(ctx), p)(t) + }) }) } } From 3777ccfbf5d6e18b3d178c39bce4dcea03d1ebe1 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Fri, 13 Sep 2024 13:32:13 +0200 Subject: [PATCH 17/21] chore: format --- persistence/sql/batch/test_persister.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/persistence/sql/batch/test_persister.go b/persistence/sql/batch/test_persister.go index 7d476ca60150..248498db775a 100644 --- a/persistence/sql/batch/test_persister.go +++ b/persistence/sql/batch/test_persister.go @@ -1,3 +1,6 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + package batch import ( From 9484feb20e501e359c58d4f9ab272d677fb87ded Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Mon, 16 Sep 2024 13:12:57 +0200 Subject: [PATCH 18/21] only use partial inserts on batch inserts --- internal/client-go/go.sum | 1 + ...t_buildInsertQueryArgs-case=cockroach.json | 2 +- ...yValues-case=testModel-case=cockroach.json | 10 ++ persistence/sql/batch/create.go | 149 +++++++++++++++--- persistence/sql/batch/create_test.go | 74 +++++---- persistence/sql/batch/test_persister.go | 99 +++++++----- .../sql/identity/persister_identity.go | 24 ++- 7 files changed, 263 insertions(+), 96 deletions(-) create mode 100644 persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json b/persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json index ef25e1645b65..4fc722f33afd 100644 --- a/persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json +++ b/persistence/sql/batch/.snapshots/Test_buildInsertQueryArgs-case=cockroach.json @@ -11,5 +11,5 @@ "traits", "updated_at" ], - "Placeholders": "(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?)" + "Placeholders": "(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?)" } diff --git a/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json b/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json new file mode 100644 index 000000000000..04c9db394e80 --- /dev/null +++ b/persistence/sql/batch/.snapshots/Test_buildInsertQueryValues-case=testModel-case=cockroach.json @@ -0,0 +1,10 @@ +[ + "2023-01-01T00:00:00Z", + "2023-01-01T00:00:00Z", + "string", + 42, + null, + { + "foo": "bar" + } +] diff --git a/persistence/sql/batch/create.go b/persistence/sql/batch/create.go index e6f69aef5c8d..9898a95ee5c6 100644 --- a/persistence/sql/batch/create.go +++ b/persistence/sql/batch/create.go @@ -5,6 +5,7 @@ package batch import ( "context" + "database/sql" "fmt" "reflect" "slices" @@ -55,7 +56,7 @@ func (p *PartialConflictError[T]) ErrOrNil() error { return p } -func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *reflectx.Mapper, quoter quoter, models []*T) insertQueryArgs { +func buildInsertQueryArgs[T any](ctx context.Context, models []*T, opts *createOpts) insertQueryArgs { var ( v T model = pop.NewModel(v, ctx) @@ -75,7 +76,7 @@ func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *re sort.Strings(columns) for _, col := range columns { - quotedColumns = append(quotedColumns, quoter.Quote(col)) + quotedColumns = append(quotedColumns, opts.quoter.Quote(col)) } // We generate a list (for every row one) of VALUE statements here that @@ -84,29 +85,51 @@ func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *re // (?, ?, ?, ?), // (?, ?, ?, ?), // (?, ?, ?, ?) - for range models { + for _, m := range models { + m := reflect.ValueOf(m) + pl := make([]string, len(placeholderRow)) copy(pl, placeholderRow) + // There is a special case - when using CockroachDB we want to generate + // UUIDs using "gen_random_uuid()" which ends up in a VALUE statement of: + // + // (gen_random_uuid(), ?, ?, ?), + for k := range placeholderRow { + if columns[k] != "id" { + continue + } + + field := opts.mapper.FieldByName(m, columns[k]) + val, ok := field.Interface().(uuid.UUID) + if !ok { + continue + } + + if val == uuid.Nil && opts.dialect == dbal.DriverCockroachDB && !opts.partialInserts { + pl[k] = "gen_random_uuid()" + break + } + } placeholders = append(placeholders, fmt.Sprintf("(%s)", strings.Join(pl, ", "))) } return insertQueryArgs{ - TableName: quoter.Quote(model.TableName()), + TableName: opts.quoter.Quote(model.TableName()), ColumnsDecl: strings.Join(quotedColumns, ", "), Columns: columns, Placeholders: strings.Join(placeholders, ",\n"), } } -func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, columns []string, models []*T, nowFunc func() time.Time) (values []any, err error) { +func buildInsertQueryValues[T any](columns []string, models []*T, opts *createOpts) (values []any, err error) { for _, m := range models { m := reflect.ValueOf(m) - now := nowFunc() + now := opts.now() // Append model fields to args for _, c := range columns { - field := mapper.FieldByName(m, c) + field := opts.mapper.FieldByName(m, c) switch c { case "created_at": @@ -118,6 +141,19 @@ func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, colu case "id": if field.Interface().(uuid.UUID) != uuid.Nil { break // breaks switch, not for + } else if opts.dialect == dbal.DriverCockroachDB && !opts.partialInserts { + // This is a special case: + // 1. We're using cockroach + // 2. It's the primary key field ("ID") + // 3. A UUID was not yet set. + // + // If all these conditions meet, the VALUE statement will look as such: + // + // (gen_random_uuid(), ?, ?, ?, ...) + // + // For that reason, we do not add the ID value to the list of arguments, + // because one of the arguments is using a built-in and thus doesn't need a value. + continue // break switch, not for } id, err := uuid.NewV4() @@ -142,9 +178,48 @@ func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, colu return values, nil } +type createOpts struct { + partialInserts bool + dialect string + mapper *reflectx.Mapper + quoter quoter + now func() time.Time +} + +type CreateOpts func(*createOpts) + +// WithPartialInserts allows to insert only the models that do not conflict with +// an existing record. WithPartialInserts will also generate the IDs for the +// models before inserting them, so that the successful inserts can be correlated +// with the input models. +// +// In particular, WithPartialInserts does not work with MySQL, because it does +// not support the "RETURNING" clause. +// +// WithPartialInserts does not work with CockroachDB and gen_random_uuid(), +// because then the successful inserts cannot be correlated with the input +// models. Note: gen_random_uuid() will skip the UNIQUE constraint check, which +// needs to hit all regions in a distributed setup. Therefore, WithPartialInserts +// should not be used to insert models for only a single identity. +var WithPartialInserts CreateOpts = func(o *createOpts) { + o.partialInserts = true +} + +func newCreateOpts(conn *pop.Connection, opts ...CreateOpts) *createOpts { + o := new(createOpts) + o.dialect = conn.Dialect.Name() + o.mapper = conn.TX.Mapper + o.quoter = conn.Dialect.(quoter) + o.now = func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) } + for _, f := range opts { + f(o) + } + return o +} + // Create batch-inserts the given models into the database using a single INSERT statement. // The models are either all created or none. -func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err error) { +func Create[T any](ctx context.Context, p *TracerConnection, models []*T, opts ...CreateOpts) (err error) { ctx, span := p.Tracer.Tracer().Start(ctx, "persistence.sql.batch.Create", trace.WithAttributes(attribute.Int("count", len(models)))) defer otelx.End(span, &err) @@ -157,13 +232,10 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e model := pop.NewModel(v, ctx) conn := p.Connection - quoter, ok := conn.Dialect.(quoter) - if !ok { - return errors.Errorf("store is not a quoter: %T", conn.Store) - } + options := newCreateOpts(conn, opts...) - queryArgs := buildInsertQueryArgs(ctx, conn.Dialect.Name(), conn.TX.Mapper, quoter, models) - values, err := buildInsertQueryValues(conn.Dialect.Name(), conn.TX.Mapper, queryArgs.Columns, models, func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) }) + queryArgs := buildInsertQueryArgs(ctx, models, options) + values, err := buildInsertQueryValues(queryArgs.Columns, models, options) if err != nil { return err } @@ -171,7 +243,11 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e var returningClause string if conn.Dialect.Name() != dbal.DriverMySQL { // PostgreSQL, CockroachDB, SQLite support RETURNING. - returningClause = fmt.Sprintf("ON CONFLICT DO NOTHING RETURNING %s", model.IDField()) + if options.partialInserts { + returningClause = fmt.Sprintf("ON CONFLICT DO NOTHING RETURNING %s", model.IDField()) + } else { + returningClause = fmt.Sprintf("RETURNING %s", model.IDField()) + } } query := conn.Dialect.TranslateSQL(fmt.Sprintf( @@ -193,15 +269,36 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e return sqlcon.HandleError(rows.Close()) } - idIdx := slices.Index(queryArgs.Columns, "id") - if idIdx == -1 { - return errors.New("id column not found") + if options.partialInserts { + return handlePartialInserts(ctx, queryArgs, values, models, rows) + } else { + return handleFullInserts(ctx, models, rows) } - var idValues []uuid.UUID - for i := idIdx; i < len(values); i += len(queryArgs.Columns) { - idValues = append(idValues, values[i].(uuid.UUID)) + +} + +func handleFullInserts[T any](ctx context.Context, models []*T, rows *sql.Rows) error { + // Hydrate the models from the RETURNING clause. + for i := 0; rows.Next(); i++ { + if err := rows.Err(); err != nil { + return sqlcon.HandleError(err) + } + var id uuid.UUID + if err := rows.Scan(&id); err != nil { + return errors.WithStack(err) + } + if err := setModelID(id, pop.NewModel(models[i], ctx)); err != nil { + return err + } + } + if err := rows.Err(); err != nil { + return sqlcon.HandleError(err) } + return sqlcon.HandleError(rows.Close()) +} + +func handlePartialInserts[T any](ctx context.Context, queryArgs insertQueryArgs, values []any, models []*T, rows *sql.Rows) error { // Hydrate the models from the RETURNING clause. idsInDB := make(map[uuid.UUID]struct{}) for rows.Next() { @@ -222,6 +319,15 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e return sqlcon.HandleError(err) } + idIdx := slices.Index(queryArgs.Columns, "id") + if idIdx == -1 { + return errors.New("id column not found") + } + var idValues []uuid.UUID + for i := idIdx; i < len(values); i += len(queryArgs.Columns) { + idValues = append(idValues, values[i].(uuid.UUID)) + } + var partialConflictError PartialConflictError[T] for i, id := range idValues { if _, ok := idsInDB[id]; !ok { @@ -238,6 +344,7 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e } return nil + } // setModelID was copy & pasted from pop. It basically sets diff --git a/persistence/sql/batch/create_test.go b/persistence/sql/batch/create_test.go index 2449cc6f41a4..131589478e6e 100644 --- a/persistence/sql/batch/create_test.go +++ b/persistence/sql/batch/create_test.go @@ -9,14 +9,13 @@ import ( "testing" "time" - "github.com/ory/x/dbal" - "github.com/gofrs/uuid" "github.com/jmoiron/sqlx/reflectx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ory/kratos/identity" + "github.com/ory/x/dbal" "github.com/ory/x/snapshotx" "github.com/ory/x/sqlxx" ) @@ -53,8 +52,11 @@ func Test_buildInsertQueryArgs(t *testing.T) { ctx := context.Background() t.Run("case=testModel", func(t *testing.T) { models := makeModels[testModel]() - mapper := reflectx.NewMapper("db") - args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) + opts := &createOpts{ + dialect: "other", + quoter: testQuoter{}, + mapper: reflectx.NewMapper("db")} + args := buildInsertQueryArgs(ctx, models, opts) snapshotx.SnapshotT(t, args) query := fmt.Sprintf("INSERT INTO %s (%s) VALUES\n%s", args.TableName, args.ColumnsDecl, args.Placeholders) @@ -73,22 +75,31 @@ func Test_buildInsertQueryArgs(t *testing.T) { t.Run("case=Identities", func(t *testing.T) { models := makeModels[identity.Identity]() - mapper := reflectx.NewMapper("db") - args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) + opts := &createOpts{ + dialect: "other", + quoter: testQuoter{}, + mapper: reflectx.NewMapper("db")} + args := buildInsertQueryArgs(ctx, models, opts) snapshotx.SnapshotT(t, args) }) t.Run("case=RecoveryAddress", func(t *testing.T) { models := makeModels[identity.RecoveryAddress]() - mapper := reflectx.NewMapper("db") - args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) + opts := &createOpts{ + dialect: "other", + quoter: testQuoter{}, + mapper: reflectx.NewMapper("db")} + args := buildInsertQueryArgs(ctx, models, opts) snapshotx.SnapshotT(t, args) }) t.Run("case=RecoveryAddress", func(t *testing.T) { models := makeModels[identity.RecoveryAddress]() - mapper := reflectx.NewMapper("db") - args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models) + opts := &createOpts{ + dialect: "other", + quoter: testQuoter{}, + mapper: reflectx.NewMapper("db")} + args := buildInsertQueryArgs(ctx, models, opts) snapshotx.SnapshotT(t, args) }) @@ -99,8 +110,11 @@ func Test_buildInsertQueryArgs(t *testing.T) { models[k].ID = uuid.FromStringOrNil(fmt.Sprintf("ae0125a9-2786-4ada-82d2-d169cf75047%d", k)) } } - mapper := reflectx.NewMapper("db") - args := buildInsertQueryArgs(ctx, "cockroach", mapper, testQuoter{}, models) + opts := &createOpts{ + dialect: dbal.DriverCockroachDB, + quoter: testQuoter{}, + mapper: reflectx.NewMapper("db")} + args := buildInsertQueryArgs(ctx, models, opts) snapshotx.SnapshotT(t, args) }) } @@ -112,32 +126,32 @@ func Test_buildInsertQueryValues(t *testing.T) { Int: 42, Traits: []byte(`{"foo": "bar"}`), } - mapper := reflectx.NewMapper("db") frozenTime := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) - nowFunc := func() time.Time { return frozenTime } + opts := &createOpts{ + mapper: reflectx.NewMapper("db"), + quoter: testQuoter{}, + now: func() time.Time { return frozenTime }, + } t.Run("case=cockroach", func(t *testing.T) { - values, err := buildInsertQueryValues(dbal.DriverCockroachDB, mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc) + opts.dialect = dbal.DriverCockroachDB + values, err := buildInsertQueryValues( + []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, + []*testModel{model}, + opts, + ) require.NoError(t, err) - - assert.Equal(t, frozenTime, model.CreatedAt) - assert.Equal(t, model.CreatedAt, values[0]) - - assert.Equal(t, frozenTime, model.UpdatedAt) - assert.Equal(t, model.UpdatedAt, values[1]) - - assert.NotZero(t, model.ID) - assert.Equal(t, model.ID, values[2]) - - assert.Equal(t, model.String, values[3]) - assert.Equal(t, model.Int, values[4]) - - assert.Nil(t, model.NullTimePtr) + snapshotx.SnapshotT(t, values) }) t.Run("case=others", func(t *testing.T) { - values, err := buildInsertQueryValues("other", mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc) + opts.dialect = "other" + values, err := buildInsertQueryValues( + []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, + []*testModel{model}, + opts, + ) require.NoError(t, err) assert.Equal(t, frozenTime, model.CreatedAt) diff --git a/persistence/sql/batch/test_persister.go b/persistence/sql/batch/test_persister.go index 248498db775a..be0e9ac7a4c5 100644 --- a/persistence/sql/batch/test_persister.go +++ b/persistence/sql/batch/test_persister.go @@ -5,6 +5,7 @@ package batch import ( "context" + "errors" "testing" "github.com/gobuffalo/pop/v6" @@ -45,48 +46,66 @@ func TestPersister(ctx context.Context, tracer *otelx.Tracer, p persistence.Pers require.NotEqual(t, uuid.Nil, ident2.ID) // Create conflicting verifiable addresses - _ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { - conn := &TracerConnection{ - Tracer: tracer, - Connection: tx, - } + addresses := []*identity.VerifiableAddress{{ + Value: "foo.1@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.2@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "conflict@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.3@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "conflict@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.4@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }} - err := Create(ctx, conn, []*identity.VerifiableAddress{{ - Value: "foo.1@bar.de", - IdentityID: ident1.ID, - NID: ident1.NID, - }, { - Value: "foo.2@bar.de", - IdentityID: ident1.ID, - NID: ident1.NID, - }, { - Value: "conflict@bar.de", - IdentityID: ident1.ID, - NID: ident1.NID, - }, { - Value: "foo.3@bar.de", - IdentityID: ident1.ID, - NID: ident1.NID, - }, { - Value: "conflict@bar.de", - IdentityID: ident1.ID, - NID: ident1.NID, - }, { - Value: "foo.4@bar.de", - IdentityID: ident1.ID, - NID: ident1.NID, - }}) - - assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation) - - if conn.Connection.Dialect.Name() != dbal.DriverMySQL { - // MySQL does not support partial errors. - partialErr := new(PartialConflictError[identity.VerifiableAddress]) - require.ErrorAs(t, err, &partialErr) - assert.Len(t, partialErr.Failed, 1) - } + t.Run("case=fails all without partial inserts", func(t *testing.T) { + _ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + conn := &TracerConnection{ + Tracer: tracer, + Connection: tx, + } + err := Create(ctx, conn, addresses) + assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation) + if partial := new(PartialConflictError[identity.VerifiableAddress]); errors.As(err, &partial) { + require.NoError(t, partial, "expected no partial error") + } + return err + }) + }) - return nil + t.Run("case=return partial error with partial inserts", func(t *testing.T) { + _ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + conn := &TracerConnection{ + Tracer: tracer, + Connection: tx, + } + + err := Create(ctx, conn, addresses, WithPartialInserts) + assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation) + + if conn.Connection.Dialect.Name() != dbal.DriverMySQL { + // MySQL does not support partial errors. + partialErr := new(PartialConflictError[identity.VerifiableAddress]) + require.ErrorAs(t, err, &partialErr) + assert.Len(t, partialErr.Failed, 1) + } + + return nil + }) }) }) } diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 64643bde529f..9004ea0bcfbc 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -324,6 +324,11 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn identifiers []*identity.CredentialIdentifier ) + var opts []batch.CreateOpts + if len(identities) > 1 { + opts = append(opts, batch.WithPartialInserts) + } + for _, ident := range identities { for k := range ident.Credentials { cred := ident.Credentials[k] @@ -349,7 +354,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn ident.Credentials[k] = cred } } - if err = batch.Create(ctx, traceConn, credentials); err != nil { + if err = batch.Create(ctx, traceConn, credentials, opts...); err != nil { return err } @@ -377,7 +382,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn } } - if err = batch.Create(ctx, traceConn, identifiers); err != nil { + if err = batch.Create(ctx, traceConn, identifiers, opts...); err != nil { return err } @@ -397,8 +402,12 @@ func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, conn work = append(work, &id.VerifiableAddresses[i]) } } + var opts []batch.CreateOpts + if len(identities) > 1 { + opts = append(opts, batch.WithPartialInserts) + } - return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work) + return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work, opts...) } func updateAssociation[T interface { @@ -509,7 +518,12 @@ func (p *IdentityPersister) createRecoveryAddresses(ctx context.Context, conn *p } } - return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work) + var opts []batch.CreateOpts + if len(identities) > 1 { + opts = append(opts, batch.WithPartialInserts) + } + + return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work, opts...) } func (p *IdentityPersister) CountIdentities(ctx context.Context) (n int64, err error) { @@ -574,6 +588,8 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... Connection: tx, } + // Don't use batch.WithPartialInserts, because identities have no other + // constraints other than the primary key that could cause conflicts. if err := batch.Create(ctx, conn, identities); err != nil { return sqlcon.HandleError(err) } From b593fc6a5047012e7197da2fdab2acefef4e73ec Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 17 Sep 2024 10:11:27 +0200 Subject: [PATCH 19/21] code review --- persistence/sql/batch/create.go | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/persistence/sql/batch/create.go b/persistence/sql/batch/create.go index 9898a95ee5c6..7abab0c68253 100644 --- a/persistence/sql/batch/create.go +++ b/persistence/sql/batch/create.go @@ -262,6 +262,8 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T, opts . if err != nil { return sqlcon.HandleError(err) } + defer rows.Close() + // MySQL, which does not support RETURNING, also does not have ON CONFLICT DO // NOTHING, meaning that MySQL will always fail the whole transaction on a single // record conflict. @@ -270,14 +272,14 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T, opts . } if options.partialInserts { - return handlePartialInserts(ctx, queryArgs, values, models, rows) + return handlePartialInserts(queryArgs, values, models, rows) } else { - return handleFullInserts(ctx, models, rows) + return handleFullInserts(models, rows) } } -func handleFullInserts[T any](ctx context.Context, models []*T, rows *sql.Rows) error { +func handleFullInserts[T any](models []*T, rows *sql.Rows) error { // Hydrate the models from the RETURNING clause. for i := 0; rows.Next(); i++ { if err := rows.Err(); err != nil { @@ -287,7 +289,7 @@ func handleFullInserts[T any](ctx context.Context, models []*T, rows *sql.Rows) if err := rows.Scan(&id); err != nil { return errors.WithStack(err) } - if err := setModelID(id, pop.NewModel(models[i], ctx)); err != nil { + if err := setModelID(id, models[i]); err != nil { return err } } @@ -298,7 +300,7 @@ func handleFullInserts[T any](ctx context.Context, models []*T, rows *sql.Rows) return sqlcon.HandleError(rows.Close()) } -func handlePartialInserts[T any](ctx context.Context, queryArgs insertQueryArgs, values []any, models []*T, rows *sql.Rows) error { +func handlePartialInserts[T any](queryArgs insertQueryArgs, values []any, models []*T, rows *sql.Rows) error { // Hydrate the models from the RETURNING clause. idsInDB := make(map[uuid.UUID]struct{}) for rows.Next() { @@ -333,7 +335,7 @@ func handlePartialInserts[T any](ctx context.Context, queryArgs insertQueryArgs, if _, ok := idsInDB[id]; !ok { partialConflictError.Failed = append(partialConflictError.Failed, models[i]) } else { - if err := setModelID(id, pop.NewModel(models[i], ctx)); err != nil { + if err := setModelID(id, models[i]); err != nil { return err } } @@ -347,15 +349,14 @@ func handlePartialInserts[T any](ctx context.Context, queryArgs insertQueryArgs, } -// setModelID was copy & pasted from pop. It basically sets -// the primary key to the given value read from the SQL row. -func setModelID(id uuid.UUID, model *pop.Model) error { - el := reflect.ValueOf(model.Value).Elem() - fbn := el.FieldByName("ID") - if !fbn.IsValid() { +// setModelID sets the id field of the model to the id. +func setModelID(id uuid.UUID, model any) error { + el := reflect.ValueOf(model).Elem() + idField := el.FieldByName("ID") + if !idField.IsValid() { return errors.New("model does not have a field named id") } - fbn.Set(reflect.ValueOf(id)) + idField.Set(reflect.ValueOf(id)) return nil } From 99a010c89c78a7831411e8ebc464994e7754b921 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 17 Sep 2024 11:41:23 +0200 Subject: [PATCH 20/21] fix: usage of rows.Err and rows.Close --- persistence/sql/batch/create.go | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/persistence/sql/batch/create.go b/persistence/sql/batch/create.go index 7abab0c68253..01bc85a732a0 100644 --- a/persistence/sql/batch/create.go +++ b/persistence/sql/batch/create.go @@ -268,7 +268,7 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T, opts . // NOTHING, meaning that MySQL will always fail the whole transaction on a single // record conflict. if conn.Dialect.Name() == dbal.DriverMySQL { - return sqlcon.HandleError(rows.Close()) + return nil } if options.partialInserts { @@ -282,9 +282,6 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T, opts . func handleFullInserts[T any](models []*T, rows *sql.Rows) error { // Hydrate the models from the RETURNING clause. for i := 0; rows.Next(); i++ { - if err := rows.Err(); err != nil { - return sqlcon.HandleError(err) - } var id uuid.UUID if err := rows.Scan(&id); err != nil { return errors.WithStack(err) @@ -297,16 +294,13 @@ func handleFullInserts[T any](models []*T, rows *sql.Rows) error { return sqlcon.HandleError(err) } - return sqlcon.HandleError(rows.Close()) + return nil } func handlePartialInserts[T any](queryArgs insertQueryArgs, values []any, models []*T, rows *sql.Rows) error { // Hydrate the models from the RETURNING clause. idsInDB := make(map[uuid.UUID]struct{}) for rows.Next() { - if err := rows.Err(); err != nil { - return sqlcon.HandleError(err) - } var id uuid.UUID if err := rows.Scan(&id); err != nil { return errors.WithStack(err) @@ -317,10 +311,6 @@ func handlePartialInserts[T any](queryArgs insertQueryArgs, values []any, models return sqlcon.HandleError(err) } - if err := rows.Close(); err != nil { - return sqlcon.HandleError(err) - } - idIdx := slices.Index(queryArgs.Columns, "id") if idIdx == -1 { return errors.New("id column not found") @@ -346,7 +336,6 @@ func handlePartialInserts[T any](queryArgs insertQueryArgs, values []any, models } return nil - } // setModelID sets the id field of the model to the id. From efa05f6ee1a2eaf6c672a88945df8a91e32ec25d Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 17 Sep 2024 11:52:47 +0200 Subject: [PATCH 21/21] code review --- persistence/sql/batch/create.go | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/persistence/sql/batch/create.go b/persistence/sql/batch/create.go index 01bc85a732a0..30b3c9768a7a 100644 --- a/persistence/sql/batch/create.go +++ b/persistence/sql/batch/create.go @@ -41,6 +41,9 @@ type ( Connection *pop.Connection } + // PartialConflictError represents a partial conflict during [Create]. It always + // wraps a [sqlcon.ErrUniqueViolation], so that the caller can either abort the + // whole transaction, or handle the partial success. PartialConflictError[T any] struct { Failed []*T } @@ -55,6 +58,12 @@ func (p *PartialConflictError[T]) ErrOrNil() error { } return p } +func (p *PartialConflictError[T]) Unwrap() error { + if len(p.Failed) == 0 { + return nil + } + return sqlcon.ErrUniqueViolation +} func buildInsertQueryArgs[T any](ctx context.Context, models []*T, opts *createOpts) insertQueryArgs { var ( @@ -217,8 +226,11 @@ func newCreateOpts(conn *pop.Connection, opts ...CreateOpts) *createOpts { return o } -// Create batch-inserts the given models into the database using a single INSERT statement. -// The models are either all created or none. +// Create batch-inserts the given models into the database using a single INSERT +// statement. By default, the models are either all created or none. If +// [WithPartialInserts] is passed as an option, partial inserts are supported, +// and the models that could not be inserted are returned in an +// [PartialConflictError]. func Create[T any](ctx context.Context, p *TracerConnection, models []*T, opts ...CreateOpts) (err error) { ctx, span := p.Tracer.Tracer().Start(ctx, "persistence.sql.batch.Create", trace.WithAttributes(attribute.Int("count", len(models)))) @@ -331,11 +343,7 @@ func handlePartialInserts[T any](queryArgs insertQueryArgs, values []any, models } } - if len(partialConflictError.Failed) > 0 { - return sqlcon.ErrUniqueViolation.WithWrap(&partialConflictError) - } - - return nil + return partialConflictError.ErrOrNil() } // setModelID sets the id field of the model to the id.