diff --git a/lib/services/local/access_list.go b/lib/services/local/access_list.go index d7d748546a1a3..611d9de23f094 100644 --- a/lib/services/local/access_list.go +++ b/lib/services/local/access_list.go @@ -33,6 +33,7 @@ import ( accesslistv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accesslist/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/accesslist" + "github.com/gravitational/teleport/api/types/common" "github.com/gravitational/teleport/api/types/header" "github.com/gravitational/teleport/entitlements" "github.com/gravitational/teleport/lib/accesslists" @@ -544,6 +545,11 @@ func (a *AccessListService) UpsertAccessListMember(ctx context.Context, member * if err != nil { return trace.Wrap(err) } + existingMember, err := a.memberService.GetResource(ctx, member.GetName()) + if err != nil && !trace.IsNotFound(err) { + return trace.Wrap(err) + } + keepAWSIdentityCenterLabels(existingMember, member) if err := accesslists.ValidateAccessListMember(ctx, memberList, member, &accessListAndMembersGetter{a.service, a.memberService}); err != nil { return trace.Wrap(err) @@ -575,6 +581,11 @@ func (a *AccessListService) UpdateAccessListMember(ctx context.Context, member * if err != nil { return trace.Wrap(err) } + existingMember, err := a.memberService.GetResource(ctx, member.GetName()) + if err != nil && !trace.IsNotFound(err) { + return trace.Wrap(err) + } + keepAWSIdentityCenterLabels(existingMember, member) if err := accesslists.ValidateAccessListMember(ctx, memberList, member, &accessListAndMembersGetter{a.service, a.memberService}); err != nil { return trace.Wrap(err) @@ -732,6 +743,7 @@ func (a *AccessListService) UpsertAccessListWithMembers(ctx context.Context, acc if existingMember.Spec.Reason != "" { newMember.Spec.Reason = existingMember.Spec.Reason } + keepAWSIdentityCenterLabels(existingMember, newMember) newMember.Spec.AddedBy = existingMember.Spec.AddedBy // Compare members and update if necessary. @@ -1029,3 +1041,17 @@ func (a *AccessListService) VerifyAccessListCreateLimit(ctx context.Context, tar const limitReachedMessage = "cluster has reached its limit for creating access lists, please contact the cluster administrator" return trace.AccessDenied(limitReachedMessage) } + +// keepAWSIdentityCenterLabels preserves member labels if +// it originated from AWS Identity Center plugin. +// The Web UI does not currently preserve metadata labels so this function should be called +// in every update/upsert member calls. +// Remove this function once https://github.com/gravitational/teleport.e/issues/5415 is addressed. +func keepAWSIdentityCenterLabels(old, new *accesslist.AccessListMember) { + if old == nil || new == nil { + return + } + if old.Origin() == common.OriginAWSIdentityCenter { + new.Metadata.Labels = old.GetAllLabels() + } +} diff --git a/lib/services/local/access_list_test.go b/lib/services/local/access_list_test.go index 97bcdd5eb44ae..f1bac7e27535d 100644 --- a/lib/services/local/access_list_test.go +++ b/lib/services/local/access_list_test.go @@ -32,6 +32,7 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types/accesslist" + "github.com/gravitational/teleport/api/types/common" "github.com/gravitational/teleport/api/types/header" "github.com/gravitational/teleport/api/types/trait" "github.com/gravitational/teleport/entitlements" @@ -662,6 +663,51 @@ func TestAccessListMembersCRUD(t *testing.T) { require.ErrorIs(t, err, trace.NotFound("access_list %q doesn't exist", accessList2.GetName())) } +func TestUpsertAndUpdateAccessListWithMembers_PreservesIdentityCenterLablesForExistingMembers(t *testing.T) { + ctx := context.Background() + clock := clockwork.NewFakeClock() + mem, err := memory.New(memory.Config{ + Context: ctx, + Clock: clock, + }) + require.NoError(t, err) + service := newAccessListService(t, mem, clock, true /* igsEnabled */) + + accessList1 := newAccessList(t, "accessList1", clock) + _, err = service.UpsertAccessList(ctx, accessList1) + require.NoError(t, err) + accessList1Member1 := newAccessListMember(t, accessList1.GetName(), "aws-ic-user") + accessList1Member1.SetOrigin(common.OriginAWSIdentityCenter) + accessList1Member1.Metadata.Labels["foo"] = "bar" + + _, err = service.UpsertAccessListMember(ctx, accessList1Member1) + require.NoError(t, err) + + member, err := service.GetAccessListMember(ctx, accessList1.GetName(), accessList1Member1.GetName()) + require.NoError(t, err) + require.Empty( + t, + cmp.Diff( + accessList1Member1, + member, + cmpopts.IgnoreFields(header.Metadata{}, "Revision"), + cmpopts.IgnoreFields(accesslist.AccessListMemberSpec{}, "Joined"), + )) + + dupeMemberButWithoutOriginLabel := newAccessListMember(t, accessList1.GetName(), "aws-ic-user") + _, updatedMembers, err := service.UpsertAccessListWithMembers(ctx, accessList1, []*accesslist.AccessListMember{dupeMemberButWithoutOriginLabel}) + require.NoError(t, err) + require.Equal(t, "bar", updatedMembers[0].GetMetadata().Labels["foo"]) + + updatedMember, err := service.UpdateAccessListMember(ctx, dupeMemberButWithoutOriginLabel) + require.NoError(t, err) + require.Equal(t, "bar", updatedMember.GetMetadata().Labels["foo"]) + + upsertedMember, err := service.UpdateAccessListMember(ctx, dupeMemberButWithoutOriginLabel) + require.NoError(t, err) + require.Equal(t, "bar", upsertedMember.GetMetadata().Labels["foo"]) +} + func TestAccessListReviewCRUD(t *testing.T) { ctx := context.Background() clock := clockwork.NewFakeClock()