Skip to content

Commit

Permalink
feat: allow partially failing batch inserts (#4083)
Browse files Browse the repository at this point in the history
When batch-inserting multiple identities, conflicts or validation errors of a subset of identities in the batch still allow the rest of the identities to be inserted. The returned JSON contains the error details that lead to the failure.

---------

Co-authored-by: Patrik <[email protected]>
  • Loading branch information
hperl and zepatrik authored Sep 17, 2024
1 parent e451b74 commit 4ba7033
Show file tree
Hide file tree
Showing 22 changed files with 697 additions and 189 deletions.
13 changes: 11 additions & 2 deletions identity/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,13 +617,22 @@ func (h *Handler) batchPatchIdentities(w http.ResponseWriter, r *http.Request, _
}
}

if err := h.r.IdentityManager().CreateIdentities(r.Context(), identities); err != nil {
err := h.r.IdentityManager().CreateIdentities(r.Context(), identities)
partialErr := new(CreateIdentitiesError)
if err != nil && !errors.As(err, &partialErr) {
h.r.Writer().WriteError(w, r, err)
return
}
for resIdx, identitiesIdx := range indexInIdentities {
if identitiesIdx != nil {
res.Identities[resIdx].IdentityID = &identities[*identitiesIdx].ID
ident := identities[*identitiesIdx]
// Check if the identity was created successfully.
if failed := partialErr.Find(ident); failed != nil {
res.Identities[resIdx].Action = ActionError
res.Identities[resIdx].Error = failed.Error
} else {
res.Identities[resIdx].IdentityID = &ident.ID
}
}
}

Expand Down
90 changes: 50 additions & 40 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -759,53 +759,63 @@ func TestHandler(t *testing.T) {
assert.Contains(t, res.Get("error.reason").String(), strconv.Itoa(identity.BatchPatchIdentitiesLimit),
"the error reason should contain the limit")
})
t.Run("case=fails all on a bad identity", func(t *testing.T) {
t.Run("case=fails some on a bad identity", func(t *testing.T) {
// Test setup: we have a list of valid identitiy patches and a list of invalid ones.
// Each run adds one invalid patch to the list and sends it to the server.
// --> we expectedIdentifiers the server to fail all patches in the list.
// --> we expect the server to fail all patches in the list.
// Finally, we send just the valid patches
// --> we expectedIdentifiers the server to succeed all patches in the list.
validPatches := []*identity.BatchIdentityPatch{
{Create: validCreateIdentityBody("valid-patch", 0)},
{Create: validCreateIdentityBody("valid-patch", 1)},
{Create: validCreateIdentityBody("valid-patch", 2)},
{Create: validCreateIdentityBody("valid-patch", 3)},
{Create: validCreateIdentityBody("valid-patch", 4)},
}
// --> we expect the server to succeed all patches in the list.

t.Run("case=invalid patches fail", func(t *testing.T) {
patches := []*identity.BatchIdentityPatch{
{Create: validCreateIdentityBody("valid", 0)},
{Create: validCreateIdentityBody("valid", 1)},
{Create: &identity.CreateIdentityBody{}}, // <-- invalid: missing all fields
{Create: validCreateIdentityBody("valid", 2)},
{Create: validCreateIdentityBody("valid", 0)}, // <-- duplicate
{Create: validCreateIdentityBody("valid", 3)},
{Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits
{Create: validCreateIdentityBody("valid", 4)},
}

for _, tt := range []struct {
name string
body *identity.CreateIdentityBody
expectStatus int
}{
{
name: "missing all fields",
body: &identity.CreateIdentityBody{},
expectStatus: http.StatusBadRequest,
},
{
name: "duplicate identity",
body: validCreateIdentityBody("valid-patch", 0),
expectStatus: http.StatusConflict,
},
{
name: "invalid traits",
body: &identity.CreateIdentityBody{
Traits: json.RawMessage(`"invalid traits"`),
},
expectStatus: http.StatusBadRequest,
},
} {
t.Run("invalid because "+tt.name, func(t *testing.T) {
patches := append([]*identity.BatchIdentityPatch{}, validPatches...)
patches = append(patches, &identity.BatchIdentityPatch{Create: tt.body})
// Create unique IDs for each patch
var patchIDs []string
for i, p := range patches {
id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i))
p.ID = &id
patchIDs = append(patchIDs, id.String())
}

req := &identity.BatchPatchIdentitiesBody{Identities: patches}
send(t, adminTS, "PATCH", "/identities", tt.expectStatus, req)
})
}
req := &identity.BatchPatchIdentitiesBody{Identities: patches}
body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req)
var actions []string
for _, a := range body.Get("identities.#.action").Array() {
actions = append(actions, a.String())
}
assert.Equal(t,
[]string{"create", "create", "error", "create", "error", "create", "error", "create"},
actions, body)

// Check that all patch IDs are returned
for i, gotPatchID := range body.Get("identities.#.patch_id").Array() {
assert.Equal(t, patchIDs[i], gotPatchID.String())
}

// Check specific errors
assert.Equal(t, "Bad Request", body.Get("identities.2.error.status").String())
assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String())
assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String())

})

t.Run("valid patches succeed", func(t *testing.T) {
validPatches := []*identity.BatchIdentityPatch{
{Create: validCreateIdentityBody("valid-patch", 0)},
{Create: validCreateIdentityBody("valid-patch", 1)},
{Create: validCreateIdentityBody("valid-patch", 2)},
{Create: validCreateIdentityBody("valid-patch", 3)},
{Create: validCreateIdentityBody("valid-patch", 4)},
}
req := &identity.BatchPatchIdentitiesBody{Identities: validPatches}
send(t, adminTS, "PATCH", "/identities", http.StatusOK, req)
})
Expand Down
21 changes: 11 additions & 10 deletions identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,17 @@ import (
"sync"
"time"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/samber/lo"

"github.com/tidwall/sjson"

"github.com/tidwall/gjson"

"github.com/ory/kratos/cipher"
"github.com/tidwall/sjson"

"github.com/ory/herodot"
"github.com/ory/kratos/cipher"
"github.com/ory/kratos/driver/config"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/sqlxx"

"github.com/ory/kratos/driver/config"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
)

// An Identity's State
Expand Down Expand Up @@ -645,6 +640,9 @@ const (
// Create this identity.
ActionCreate BatchPatchAction = "create"

// Error indicates that the patch failed.
ActionError BatchPatchAction = "error"

// Future actions:
//
// Delete this identity.
Expand Down Expand Up @@ -677,4 +675,7 @@ type BatchIdentityPatchResponse struct {

// The ID of this patch response, if an ID was specified in the patch.
PatchID *uuid.UUID `json:"patch_id,omitempty"`

// The error message, if the action was "error".
Error *herodot.DefaultError `json:"error,omitempty"`
}
82 changes: 74 additions & 8 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package identity
import (
"context"
"encoding/json"
"fmt"
"reflect"
"slices"
"sort"
Expand Down Expand Up @@ -323,26 +324,91 @@ func (e *ErrDuplicateCredentials) HasHints() bool {
return len(e.availableCredentials) > 0 || len(e.availableOIDCProviders) > 0 || len(e.identifierHint) > 0
}

type FailedIdentity struct {
Identity *Identity
Error *herodot.DefaultError
}

type CreateIdentitiesError struct {
failedIdentities map[*Identity]*herodot.DefaultError
}

func (e *CreateIdentitiesError) Error() string {
e.init()
return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities))
}
func (e *CreateIdentitiesError) Unwrap() []error {
e.init()
var errs []error
for _, err := range e.failedIdentities {
errs = append(errs, err)
}
return errs
}

func (e *CreateIdentitiesError) AddFailedIdentity(ident *Identity, err *herodot.DefaultError) {
e.init()
e.failedIdentities[ident] = err
}
func (e *CreateIdentitiesError) Merge(other *CreateIdentitiesError) {
e.init()
for k, v := range other.failedIdentities {
e.failedIdentities[k] = v
}
}
func (e *CreateIdentitiesError) Contains(ident *Identity) bool {
e.init()
_, found := e.failedIdentities[ident]
return found
}
func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity {
e.init()
if err, found := e.failedIdentities[ident]; found {
return &FailedIdentity{Identity: ident, Error: err}
}

return nil
}
func (e *CreateIdentitiesError) ErrOrNil() error {
if e.failedIdentities == nil || len(e.failedIdentities) == 0 {
return nil
}
return e
}
func (e *CreateIdentitiesError) init() {
if e.failedIdentities == nil {
e.failedIdentities = map[*Identity]*herodot.DefaultError{}
}
}

func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, opts ...ManagerOption) (err error) {
ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities")
defer otelx.End(span, &err)

for _, i := range identities {
if i.SchemaID == "" {
i.SchemaID = m.r.Config().DefaultIdentityTraitsSchemaID(ctx)
createIdentitiesError := &CreateIdentitiesError{}
validIdentities := make([]*Identity, 0, len(identities))
for _, ident := range identities {
if ident.SchemaID == "" {
ident.SchemaID = m.r.Config().DefaultIdentityTraitsSchemaID(ctx)
}

o := newManagerOptions(opts)
if err := m.ValidateIdentity(ctx, i, o); err != nil {
return err
if err := m.ValidateIdentity(ctx, ident, o); err != nil {
createIdentitiesError.AddFailedIdentity(ident, herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err))
continue
}
validIdentities = append(validIdentities, ident)
}

if err := m.r.PrivilegedIdentityPool().CreateIdentities(ctx, identities...); err != nil {
return err
if err := m.r.PrivilegedIdentityPool().CreateIdentities(ctx, validIdentities...); err != nil {
if partialErr := new(CreateIdentitiesError); errors.As(err, &partialErr) {
createIdentitiesError.Merge(partialErr)
} else {
return err
}
}

return nil
return createIdentitiesError.ErrOrNil()
}

func (m *Manager) requiresPrivilegedAccess(ctx context.Context, original, updated *Identity, o *ManagerOptions) (err error) {
Expand Down
6 changes: 5 additions & 1 deletion identity/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,13 @@ type (
FindByCredentialsIdentifier(ctx context.Context, ct CredentialsType, match string) (*Identity, *Credentials, error)

// DeleteIdentity removes an identity by its id. Will return an error
// if identity exists, backend connectivity is broken, or trait validation fails.
// if identity does not exists, or backend connectivity is broken.
DeleteIdentity(context.Context, uuid.UUID) error

// DeleteIdentities removes identities by its id. Will return an error
// if any identity does not exists, or backend connectivity is broken.
DeleteIdentities(context.Context, []uuid.UUID) error

// UpdateVerifiableAddress updates an identity's verifiable address.
UpdateVerifiableAddress(ctx context.Context, address *VerifiableAddress) error

Expand Down
41 changes: 39 additions & 2 deletions internal/client-go/model_identity_patch_response.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 4ba7033

Please sign in to comment.