From a2842a01c035423651d1a8a979a2edf4e9184cdf Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Thu, 12 Sep 2024 09:23:20 +0200 Subject: [PATCH] code review --- identity/handler_test.go | 99 +++++++++---------- identity/manager.go | 52 ++++++---- .../sql/identity/persister_identity.go | 7 +- 3 files changed, 81 insertions(+), 77 deletions(-) diff --git a/identity/handler_test.go b/identity/handler_test.go index 2162095c0431..33d6a46adf87 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -759,68 +759,63 @@ func TestHandler(t *testing.T) { assert.Contains(t, res.Get("error.reason").String(), strconv.Itoa(identity.BatchPatchIdentitiesLimit), "the error reason should contain the limit") }) - t.Run("case=fails all on a bad identity", func(t *testing.T) { + t.Run("case=fails some on a bad identity", func(t *testing.T) { // Test setup: we have a list of valid identitiy patches and a list of invalid ones. // Each run adds one invalid patch to the list and sends it to the server. // --> we expect the server to fail all patches in the list. // Finally, we send just the valid patches // --> we expect the server to succeed all patches in the list. - validPatches := []*identity.BatchIdentityPatch{ - {Create: validCreateIdentityBody("valid-patch", 0)}, - {Create: validCreateIdentityBody("valid-patch", 1)}, - {Create: validCreateIdentityBody("valid-patch", 2)}, - {Create: validCreateIdentityBody("valid-patch", 3)}, - {Create: validCreateIdentityBody("valid-patch", 4)}, - } - for _, tt := range []struct { - name string - body *identity.CreateIdentityBody - }{ - { - name: "missing-all-fields", - body: &identity.CreateIdentityBody{}, - }, - { - name: "duplicate-identity", - body: validCreateIdentityBody("duplicate-identity", 0), - }, - { - name: "invalid-traits", - body: &identity.CreateIdentityBody{ - Traits: json.RawMessage(`"invalid traits"`), - }, - }, - } { - t.Run("invalid because "+tt.name, func(t *testing.T) { - validPatches := []*identity.BatchIdentityPatch{ - {Create: validCreateIdentityBody(tt.name, 0)}, - {Create: validCreateIdentityBody(tt.name, 1)}, - {Create: validCreateIdentityBody(tt.name, 2)}, - {Create: validCreateIdentityBody(tt.name, 3)}, - {Create: validCreateIdentityBody(tt.name, 4)}, - } + t.Run("case=invalid patches fail", func(t *testing.T) { + patches := []*identity.BatchIdentityPatch{ + {Create: validCreateIdentityBody("valid", 0)}, + {Create: validCreateIdentityBody("valid", 1)}, + {Create: &identity.CreateIdentityBody{}}, // <-- invalid: missing all fields + {Create: validCreateIdentityBody("valid", 2)}, + {Create: validCreateIdentityBody("valid", 0)}, // <-- duplicate + {Create: validCreateIdentityBody("valid", 3)}, + {Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits + {Create: validCreateIdentityBody("valid", 4)}, + } - patches := make([]*identity.BatchIdentityPatch, 0, len(validPatches)+1) - patches = append(patches, validPatches[0:3]...) - patches = append(patches, &identity.BatchIdentityPatch{Create: tt.body}) - patches = append(patches, validPatches[3:5]...) - for i, p := range patches { - id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%s-%d", tt.name, i)) - p.ID = &id - } + // Create unique IDs for each patch + var patchIDs []string + for i, p := range patches { + id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i)) + p.ID = &id + patchIDs = append(patchIDs, id.String()) + } - req := &identity.BatchPatchIdentitiesBody{Identities: patches} - body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) - var actions []string - for _, a := range body.Get("identities.#.action").Array() { - actions = append(actions, a.String()) - } - assert.Equal(t, []string{"create", "create", "create", "error", "create", "create"}, actions, body) - }) - } + req := &identity.BatchPatchIdentitiesBody{Identities: patches} + body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) + var actions []string + for _, a := range body.Get("identities.#.action").Array() { + actions = append(actions, a.String()) + } + assert.Equal(t, + []string{"create", "create", "error", "create", "error", "create", "error", "create"}, + actions, body) + + // Check that all patch IDs are returned + for i, gotPatchID := range body.Get("identities.#.patch_id").Array() { + assert.Equal(t, patchIDs[i], gotPatchID.String()) + } + + // Check specific errors + assert.Equal(t, "Bad Request", body.Get("identities.2.error.status").String()) + assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String()) + assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String()) + + }) t.Run("valid patches succeed", func(t *testing.T) { + validPatches := []*identity.BatchIdentityPatch{ + {Create: validCreateIdentityBody("valid-patch", 0)}, + {Create: validCreateIdentityBody("valid-patch", 1)}, + {Create: validCreateIdentityBody("valid-patch", 2)}, + {Create: validCreateIdentityBody("valid-patch", 3)}, + {Create: validCreateIdentityBody("valid-patch", 4)}, + } req := &identity.BatchPatchIdentitiesBody{Identities: validPatches} send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) }) diff --git a/identity/manager.go b/identity/manager.go index fd13bc6041f2..485979ef773f 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -334,41 +334,56 @@ type FailedIdentity struct { } type CreateIdentitiesError struct { - Failed []*FailedIdentity + failedIdentities map[*Identity]*herodot.DefaultError } func (e *CreateIdentitiesError) Error() string { - return fmt.Sprintf("create identities error: %d identities failed", len(e.Failed)) + e.init() + return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities)) } func (e *CreateIdentitiesError) Unwrap() []error { + e.init() var errs []error - for _, failed := range e.Failed { - errs = append(errs, failed.Error) + for _, err := range e.failedIdentities { + errs = append(errs, err) } return errs } -func (e *CreateIdentitiesError) Contains(ident *Identity) bool { - for _, failed := range e.Failed { - if failed.Identity.ID == ident.ID { - return true - } + +func (e *CreateIdentitiesError) AddFailedIdentity(ident *Identity, err *herodot.DefaultError) { + e.init() + e.failedIdentities[ident] = err +} +func (e *CreateIdentitiesError) Merge(other *CreateIdentitiesError) { + e.init() + for k, v := range other.failedIdentities { + e.failedIdentities[k] = v } - return false +} +func (e *CreateIdentitiesError) Contains(ident *Identity) bool { + e.init() + _, found := e.failedIdentities[ident] + return found } func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity { - for _, failed := range e.Failed { - if failed.Identity.ID == ident.ID { - return failed - } + e.init() + if err, found := e.failedIdentities[ident]; found { + return &FailedIdentity{Identity: ident, Error: err} } + return nil } func (e *CreateIdentitiesError) ErrOrNil() error { - if len(e.Failed) == 0 { + if e.failedIdentities == nil || len(e.failedIdentities) == 0 { return nil } return e } +func (e *CreateIdentitiesError) init() { + if e.failedIdentities == nil { + e.failedIdentities = map[*Identity]*herodot.DefaultError{} + } +} func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, opts ...ManagerOption) (err error) { ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities") @@ -383,10 +398,7 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, o := newManagerOptions(opts) if err := m.ValidateIdentity(ctx, ident, o); err != nil { - createIdentitiesError.Failed = append(createIdentitiesError.Failed, &FailedIdentity{ - Identity: ident, - Error: herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err), - }) + createIdentitiesError.AddFailedIdentity(ident, herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err)) continue } validIdentities = append(validIdentities, ident) @@ -394,7 +406,7 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, if err := m.r.PrivilegedIdentityPool().CreateIdentities(ctx, validIdentities...); err != nil { if partialErr := new(CreateIdentitiesError); errors.As(err, &partialErr) { - createIdentitiesError.Failed = append(createIdentitiesError.Failed, partialErr.Failed...) + createIdentitiesError.Merge(partialErr) } else { return err } diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index b0ecf85e8a43..64643bde529f 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -629,10 +629,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs)) for _, ident := range identities { if _, ok := failedIdentityIDs[ident.ID]; ok { - partialErr.Failed = append(partialErr.Failed, &identity.FailedIdentity{ - Identity: ident, - Error: sqlcon.ErrUniqueViolation, - }) + partialErr.AddFailedIdentity(ident, sqlcon.ErrUniqueViolation) failedIDs = append(failedIDs, ident.ID) } } @@ -643,7 +640,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... } // Wrap the partial error with the first error that occurred, so that the caller // can continue to handle the error either as a partial error or a full error. - return partialErr.Failed[0].Error.WithWrap(partialErr) + return partialErr } return nil