Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow partially failing batch inserts #4083

Merged
merged 27 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4808d9c
feat: partial failures for batch import
hperl Sep 9, 2024
03b6534
chore: regenerate sdk
hperl Sep 9, 2024
db8c7a4
Merge branch 'master' into hperl/improve-batch-patch-identities-error…
hperl Sep 10, 2024
2b7e3e9
fix: implement Unwrap() on partial create error
hperl Sep 10, 2024
e2b7831
bump upload action
hperl Sep 10, 2024
01f199d
fix tests
hperl Sep 10, 2024
447c5d6
fix: mysql
hperl Sep 10, 2024
cec29ca
fix tests
hperl Sep 10, 2024
badcd6e
fix: uploading
hperl Sep 10, 2024
6fec60d
fix: wrapping
hperl Sep 10, 2024
04b302c
fix artifact upload names
hperl Sep 10, 2024
0ed2ed1
test: fix lookup E2E test
hperl Sep 11, 2024
03c127c
test: always use www.example.org as return_to URL in tests
hperl Sep 11, 2024
166302d
Merge branch 'master' into hperl/improve-batch-patch-identities-error…
zepatrik Sep 11, 2024
a2842a0
code review
hperl Sep 12, 2024
50b21ed
code review
hperl Sep 12, 2024
17ed6fe
fix test
hperl Sep 12, 2024
e491329
add batch persister test
hperl Sep 13, 2024
8a892c9
Merge remote-tracking branch 'origin/master' into hperl/improve-batch…
hperl Sep 13, 2024
3777ccf
chore: format
hperl Sep 13, 2024
9484feb
only use partial inserts on batch inserts
hperl Sep 16, 2024
afd8d6a
Merge remote-tracking branch 'origin/master' into hperl/improve-batch…
hperl Sep 16, 2024
b593fc6
code review
hperl Sep 17, 2024
99a010c
fix: usage of rows.Err and rows.Close
hperl Sep 17, 2024
efa05f6
code review
hperl Sep 17, 2024
910c590
Merge remote-tracking branch 'origin/master' into hperl/improve-batch…
hperl Sep 17, 2024
0fc3029
Merge branch 'master' into hperl/improve-batch-patch-identities-error…
hperl Sep 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions identity/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,13 +617,22 @@ func (h *Handler) batchPatchIdentities(w http.ResponseWriter, r *http.Request, _
}
}

if err := h.r.IdentityManager().CreateIdentities(r.Context(), identities); err != nil {
err := h.r.IdentityManager().CreateIdentities(r.Context(), identities)
partialErr := new(CreateIdentitiesError)
if err != nil && !errors.As(err, &partialErr) {
h.r.Writer().WriteError(w, r, err)
return
}
for resIdx, identitiesIdx := range indexInIdentities {
if identitiesIdx != nil {
res.Identities[resIdx].IdentityID = &identities[*identitiesIdx].ID
ident := identities[*identitiesIdx]
// Check if the identity was created successfully.
if failed := partialErr.Find(ident); failed != nil {
hperl marked this conversation as resolved.
Show resolved Hide resolved
res.Identities[resIdx].Action = ActionError
res.Identities[resIdx].Error = failed.Error
} else {
res.Identities[resIdx].IdentityID = &ident.ID
}
}
}

Expand Down
90 changes: 50 additions & 40 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -759,53 +759,63 @@ func TestHandler(t *testing.T) {
assert.Contains(t, res.Get("error.reason").String(), strconv.Itoa(identity.BatchPatchIdentitiesLimit),
"the error reason should contain the limit")
})
t.Run("case=fails all on a bad identity", func(t *testing.T) {
t.Run("case=fails some on a bad identity", func(t *testing.T) {
// Test setup: we have a list of valid identitiy patches and a list of invalid ones.
// Each run adds one invalid patch to the list and sends it to the server.
// --> we expectedIdentifiers the server to fail all patches in the list.
// --> we expect the server to fail all patches in the list.
// Finally, we send just the valid patches
// --> we expectedIdentifiers the server to succeed all patches in the list.
validPatches := []*identity.BatchIdentityPatch{
{Create: validCreateIdentityBody("valid-patch", 0)},
{Create: validCreateIdentityBody("valid-patch", 1)},
{Create: validCreateIdentityBody("valid-patch", 2)},
{Create: validCreateIdentityBody("valid-patch", 3)},
{Create: validCreateIdentityBody("valid-patch", 4)},
}
// --> we expect the server to succeed all patches in the list.

t.Run("case=invalid patches fail", func(t *testing.T) {
patches := []*identity.BatchIdentityPatch{
{Create: validCreateIdentityBody("valid", 0)},
{Create: validCreateIdentityBody("valid", 1)},
{Create: &identity.CreateIdentityBody{}}, // <-- invalid: missing all fields
{Create: validCreateIdentityBody("valid", 2)},
{Create: validCreateIdentityBody("valid", 0)}, // <-- duplicate
{Create: validCreateIdentityBody("valid", 3)},
{Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits
{Create: validCreateIdentityBody("valid", 4)},
}

for _, tt := range []struct {
name string
body *identity.CreateIdentityBody
expectStatus int
}{
{
name: "missing all fields",
body: &identity.CreateIdentityBody{},
expectStatus: http.StatusBadRequest,
},
{
name: "duplicate identity",
body: validCreateIdentityBody("valid-patch", 0),
expectStatus: http.StatusConflict,
},
{
name: "invalid traits",
body: &identity.CreateIdentityBody{
Traits: json.RawMessage(`"invalid traits"`),
},
expectStatus: http.StatusBadRequest,
},
} {
t.Run("invalid because "+tt.name, func(t *testing.T) {
patches := append([]*identity.BatchIdentityPatch{}, validPatches...)
patches = append(patches, &identity.BatchIdentityPatch{Create: tt.body})
// Create unique IDs for each patch
var patchIDs []string
for i, p := range patches {
id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i))
p.ID = &id
patchIDs = append(patchIDs, id.String())
}

req := &identity.BatchPatchIdentitiesBody{Identities: patches}
send(t, adminTS, "PATCH", "/identities", tt.expectStatus, req)
})
}
req := &identity.BatchPatchIdentitiesBody{Identities: patches}
body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req)
var actions []string
for _, a := range body.Get("identities.#.action").Array() {
actions = append(actions, a.String())
}
assert.Equal(t,
[]string{"create", "create", "error", "create", "error", "create", "error", "create"},
actions, body)

// Check that all patch IDs are returned
for i, gotPatchID := range body.Get("identities.#.patch_id").Array() {
assert.Equal(t, patchIDs[i], gotPatchID.String())
}

// Check specific errors
assert.Equal(t, "Bad Request", body.Get("identities.2.error.status").String())
assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String())
assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String())

})

t.Run("valid patches succeed", func(t *testing.T) {
validPatches := []*identity.BatchIdentityPatch{
{Create: validCreateIdentityBody("valid-patch", 0)},
{Create: validCreateIdentityBody("valid-patch", 1)},
{Create: validCreateIdentityBody("valid-patch", 2)},
{Create: validCreateIdentityBody("valid-patch", 3)},
{Create: validCreateIdentityBody("valid-patch", 4)},
}
req := &identity.BatchPatchIdentitiesBody{Identities: validPatches}
send(t, adminTS, "PATCH", "/identities", http.StatusOK, req)
})
Expand Down
21 changes: 11 additions & 10 deletions identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,17 @@ import (
"sync"
"time"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/samber/lo"

"github.com/tidwall/sjson"

"github.com/tidwall/gjson"

"github.com/ory/kratos/cipher"
"github.com/tidwall/sjson"

"github.com/ory/herodot"
"github.com/ory/kratos/cipher"
"github.com/ory/kratos/driver/config"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/sqlxx"

"github.com/ory/kratos/driver/config"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
)

// An Identity's State
Expand Down Expand Up @@ -645,6 +640,9 @@ const (
// Create this identity.
ActionCreate BatchPatchAction = "create"

// Error indicates that the patch failed.
ActionError BatchPatchAction = "error"

// Future actions:
//
// Delete this identity.
Expand Down Expand Up @@ -677,4 +675,7 @@ type BatchIdentityPatchResponse struct {

// The ID of this patch response, if an ID was specified in the patch.
PatchID *uuid.UUID `json:"patch_id,omitempty"`

// The error message, if the action was "error".
Error *herodot.DefaultError `json:"error,omitempty"`
}
82 changes: 74 additions & 8 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package identity
import (
"context"
"encoding/json"
"fmt"
"reflect"
"slices"
"sort"
Expand Down Expand Up @@ -323,26 +324,91 @@ func (e *ErrDuplicateCredentials) HasHints() bool {
return len(e.availableCredentials) > 0 || len(e.availableOIDCProviders) > 0 || len(e.identifierHint) > 0
}

type FailedIdentity struct {
Identity *Identity
Error *herodot.DefaultError
}

type CreateIdentitiesError struct {
failedIdentities map[*Identity]*herodot.DefaultError
}

func (e *CreateIdentitiesError) Error() string {
e.init()
return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities))
}
func (e *CreateIdentitiesError) Unwrap() []error {
e.init()
var errs []error
for _, err := range e.failedIdentities {
errs = append(errs, err)
}
return errs
}

func (e *CreateIdentitiesError) AddFailedIdentity(ident *Identity, err *herodot.DefaultError) {
e.init()
e.failedIdentities[ident] = err
}
func (e *CreateIdentitiesError) Merge(other *CreateIdentitiesError) {
e.init()
for k, v := range other.failedIdentities {
e.failedIdentities[k] = v
}
}
func (e *CreateIdentitiesError) Contains(ident *Identity) bool {
e.init()
_, found := e.failedIdentities[ident]
return found
}
func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity {
e.init()
if err, found := e.failedIdentities[ident]; found {
return &FailedIdentity{Identity: ident, Error: err}
}

return nil
}
func (e *CreateIdentitiesError) ErrOrNil() error {
if e.failedIdentities == nil || len(e.failedIdentities) == 0 {
return nil
}
return e
}
func (e *CreateIdentitiesError) init() {
if e.failedIdentities == nil {
e.failedIdentities = map[*Identity]*herodot.DefaultError{}
}
}

func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, opts ...ManagerOption) (err error) {
ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities")
defer otelx.End(span, &err)

for _, i := range identities {
if i.SchemaID == "" {
i.SchemaID = m.r.Config().DefaultIdentityTraitsSchemaID(ctx)
createIdentitiesError := &CreateIdentitiesError{}
validIdentities := make([]*Identity, 0, len(identities))
for _, ident := range identities {
if ident.SchemaID == "" {
ident.SchemaID = m.r.Config().DefaultIdentityTraitsSchemaID(ctx)
}

o := newManagerOptions(opts)
if err := m.ValidateIdentity(ctx, i, o); err != nil {
return err
if err := m.ValidateIdentity(ctx, ident, o); err != nil {
createIdentitiesError.AddFailedIdentity(ident, herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err))
continue
}
validIdentities = append(validIdentities, ident)
}

if err := m.r.PrivilegedIdentityPool().CreateIdentities(ctx, identities...); err != nil {
return err
if err := m.r.PrivilegedIdentityPool().CreateIdentities(ctx, validIdentities...); err != nil {
if partialErr := new(CreateIdentitiesError); errors.As(err, &partialErr) {
createIdentitiesError.Merge(partialErr)
} else {
return err
}
}

return nil
return createIdentitiesError.ErrOrNil()
}

func (m *Manager) requiresPrivilegedAccess(ctx context.Context, original, updated *Identity, o *ManagerOptions) (err error) {
Expand Down
6 changes: 5 additions & 1 deletion identity/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,13 @@ type (
FindByCredentialsIdentifier(ctx context.Context, ct CredentialsType, match string) (*Identity, *Credentials, error)

// DeleteIdentity removes an identity by its id. Will return an error
// if identity exists, backend connectivity is broken, or trait validation fails.
// if identity does not exists, or backend connectivity is broken.
DeleteIdentity(context.Context, uuid.UUID) error

// DeleteIdentities removes identities by its id. Will return an error
// if any identity does not exists, or backend connectivity is broken.
DeleteIdentities(context.Context, []uuid.UUID) error

// UpdateVerifiableAddress updates an identity's verifiable address.
UpdateVerifiableAddress(ctx context.Context, address *VerifiableAddress) error

Expand Down
41 changes: 39 additions & 2 deletions internal/client-go/model_identity_patch_response.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading