Skip to content

Commit

Permalink
code review
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed Sep 12, 2024
1 parent 166302d commit a2842a0
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 77 deletions.
99 changes: 47 additions & 52 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -759,68 +759,63 @@ func TestHandler(t *testing.T) {
assert.Contains(t, res.Get("error.reason").String(), strconv.Itoa(identity.BatchPatchIdentitiesLimit),
"the error reason should contain the limit")
})
t.Run("case=fails all on a bad identity", func(t *testing.T) {
t.Run("case=fails some on a bad identity", func(t *testing.T) {
// Test setup: we have a list of valid identitiy patches and a list of invalid ones.
// Each run adds one invalid patch to the list and sends it to the server.
// --> we expect the server to fail all patches in the list.
// Finally, we send just the valid patches
// --> we expect the server to succeed all patches in the list.
validPatches := []*identity.BatchIdentityPatch{
{Create: validCreateIdentityBody("valid-patch", 0)},
{Create: validCreateIdentityBody("valid-patch", 1)},
{Create: validCreateIdentityBody("valid-patch", 2)},
{Create: validCreateIdentityBody("valid-patch", 3)},
{Create: validCreateIdentityBody("valid-patch", 4)},
}

for _, tt := range []struct {
name string
body *identity.CreateIdentityBody
}{
{
name: "missing-all-fields",
body: &identity.CreateIdentityBody{},
},
{
name: "duplicate-identity",
body: validCreateIdentityBody("duplicate-identity", 0),
},
{
name: "invalid-traits",
body: &identity.CreateIdentityBody{
Traits: json.RawMessage(`"invalid traits"`),
},
},
} {
t.Run("invalid because "+tt.name, func(t *testing.T) {
validPatches := []*identity.BatchIdentityPatch{
{Create: validCreateIdentityBody(tt.name, 0)},
{Create: validCreateIdentityBody(tt.name, 1)},
{Create: validCreateIdentityBody(tt.name, 2)},
{Create: validCreateIdentityBody(tt.name, 3)},
{Create: validCreateIdentityBody(tt.name, 4)},
}
t.Run("case=invalid patches fail", func(t *testing.T) {
patches := []*identity.BatchIdentityPatch{
{Create: validCreateIdentityBody("valid", 0)},
{Create: validCreateIdentityBody("valid", 1)},
{Create: &identity.CreateIdentityBody{}}, // <-- invalid: missing all fields
{Create: validCreateIdentityBody("valid", 2)},
{Create: validCreateIdentityBody("valid", 0)}, // <-- duplicate
{Create: validCreateIdentityBody("valid", 3)},
{Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits
{Create: validCreateIdentityBody("valid", 4)},
}

patches := make([]*identity.BatchIdentityPatch, 0, len(validPatches)+1)
patches = append(patches, validPatches[0:3]...)
patches = append(patches, &identity.BatchIdentityPatch{Create: tt.body})
patches = append(patches, validPatches[3:5]...)
for i, p := range patches {
id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%s-%d", tt.name, i))
p.ID = &id
}
// Create unique IDs for each patch
var patchIDs []string
for i, p := range patches {
id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i))
p.ID = &id
patchIDs = append(patchIDs, id.String())
}

req := &identity.BatchPatchIdentitiesBody{Identities: patches}
body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req)
var actions []string
for _, a := range body.Get("identities.#.action").Array() {
actions = append(actions, a.String())
}
assert.Equal(t, []string{"create", "create", "create", "error", "create", "create"}, actions, body)
})
}
req := &identity.BatchPatchIdentitiesBody{Identities: patches}
body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req)
var actions []string
for _, a := range body.Get("identities.#.action").Array() {
actions = append(actions, a.String())
}
assert.Equal(t,
[]string{"create", "create", "error", "create", "error", "create", "error", "create"},
actions, body)

// Check that all patch IDs are returned
for i, gotPatchID := range body.Get("identities.#.patch_id").Array() {
assert.Equal(t, patchIDs[i], gotPatchID.String())
}

// Check specific errors
assert.Equal(t, "Bad Request", body.Get("identities.2.error.status").String())
assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String())
assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String())

})

t.Run("valid patches succeed", func(t *testing.T) {
validPatches := []*identity.BatchIdentityPatch{
{Create: validCreateIdentityBody("valid-patch", 0)},
{Create: validCreateIdentityBody("valid-patch", 1)},
{Create: validCreateIdentityBody("valid-patch", 2)},
{Create: validCreateIdentityBody("valid-patch", 3)},
{Create: validCreateIdentityBody("valid-patch", 4)},
}
req := &identity.BatchPatchIdentitiesBody{Identities: validPatches}
send(t, adminTS, "PATCH", "/identities", http.StatusOK, req)
})
Expand Down
52 changes: 32 additions & 20 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,41 +334,56 @@ type FailedIdentity struct {
}

type CreateIdentitiesError struct {
Failed []*FailedIdentity
failedIdentities map[*Identity]*herodot.DefaultError
}

func (e *CreateIdentitiesError) Error() string {
return fmt.Sprintf("create identities error: %d identities failed", len(e.Failed))
e.init()
return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities))
}
func (e *CreateIdentitiesError) Unwrap() []error {
e.init()
var errs []error
for _, failed := range e.Failed {
errs = append(errs, failed.Error)
for _, err := range e.failedIdentities {
errs = append(errs, err)
}
return errs
}
func (e *CreateIdentitiesError) Contains(ident *Identity) bool {
for _, failed := range e.Failed {
if failed.Identity.ID == ident.ID {
return true
}

func (e *CreateIdentitiesError) AddFailedIdentity(ident *Identity, err *herodot.DefaultError) {
e.init()
e.failedIdentities[ident] = err
}
func (e *CreateIdentitiesError) Merge(other *CreateIdentitiesError) {
e.init()
for k, v := range other.failedIdentities {
e.failedIdentities[k] = v
}
return false
}
func (e *CreateIdentitiesError) Contains(ident *Identity) bool {
e.init()
_, found := e.failedIdentities[ident]
return found
}
func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity {
for _, failed := range e.Failed {
if failed.Identity.ID == ident.ID {
return failed
}
e.init()
if err, found := e.failedIdentities[ident]; found {
return &FailedIdentity{Identity: ident, Error: err}
}

return nil
}
func (e *CreateIdentitiesError) ErrOrNil() error {
if len(e.Failed) == 0 {
if e.failedIdentities == nil || len(e.failedIdentities) == 0 {
return nil
}
return e
}
func (e *CreateIdentitiesError) init() {
if e.failedIdentities == nil {
e.failedIdentities = map[*Identity]*herodot.DefaultError{}
}
}

func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, opts ...ManagerOption) (err error) {
ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities")
Expand All @@ -383,18 +398,15 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity,

o := newManagerOptions(opts)
if err := m.ValidateIdentity(ctx, ident, o); err != nil {
createIdentitiesError.Failed = append(createIdentitiesError.Failed, &FailedIdentity{
Identity: ident,
Error: herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err),
})
createIdentitiesError.AddFailedIdentity(ident, herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err))
continue
}
validIdentities = append(validIdentities, ident)
}

if err := m.r.PrivilegedIdentityPool().CreateIdentities(ctx, validIdentities...); err != nil {
if partialErr := new(CreateIdentitiesError); errors.As(err, &partialErr) {
createIdentitiesError.Failed = append(createIdentitiesError.Failed, partialErr.Failed...)
createIdentitiesError.Merge(partialErr)
} else {
return err
}
Expand Down
7 changes: 2 additions & 5 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,10 +629,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs))
for _, ident := range identities {
if _, ok := failedIdentityIDs[ident.ID]; ok {
partialErr.Failed = append(partialErr.Failed, &identity.FailedIdentity{
Identity: ident,
Error: sqlcon.ErrUniqueViolation,
})
partialErr.AddFailedIdentity(ident, sqlcon.ErrUniqueViolation)
failedIDs = append(failedIDs, ident.ID)
}
}
Expand All @@ -643,7 +640,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
}
// Wrap the partial error with the first error that occurred, so that the caller
// can continue to handle the error either as a partial error or a full error.
return partialErr.Failed[0].Error.WithWrap(partialErr)
return partialErr
}

return nil
Expand Down

0 comments on commit a2842a0

Please sign in to comment.