Skip to content

Commit

Permalink
check access list owners of nested lists
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex McGrath committed Jun 4, 2024
1 parent 9375cc6 commit e62c03c
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 57 deletions.
4 changes: 2 additions & 2 deletions lib/auth/userloginstate/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func NewGenerator(config GeneratorConfig) (*Generator, error) {
accessLists: config.AccessLists,
access: config.Access,
usageEvents: config.UsageEvents,
memberChecker: services.NewAccessListMembershipChecker(config.Clock, config.AccessLists, config.AccessLists, config.Access),
memberChecker: services.NewAccessListMembershipChecker(config.Clock, config.AccessLists, config.AccessLists),
clock: config.Clock,
}, nil
}
Expand Down Expand Up @@ -192,7 +192,7 @@ func (g *Generator) addAccessListsToState(ctx context.Context, user types.User,
}

for _, accessList := range accessLists {
if err := services.IsAccessListOwner(identity, accessList); err == nil {
if err := services.IsAccessListOwner(ctx, g.accessLists, identity, accessList); err == nil {
g.grantRolesAndTraits(identity, accessList.Spec.OwnerGrants, state)
}

Expand Down
4 changes: 4 additions & 0 deletions lib/modules/modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ type AccessResourcesGetter interface {
ListResources(ctx context.Context, req proto.ListResourcesRequest) (*types.ListResourcesResponse, error)

GetAccessList(context.Context, string) (*accesslist.AccessList, error)
GetAccessLists(ctx context.Context) ([]*accesslist.AccessList, error)

ListAccessListMembers(ctx context.Context, accessList string, pageSize int, pageToken string) (members []*accesslist.AccessListMember, nextToken string, err error)
GetAccessListMember(ctx context.Context, accessList string, memberName string) (*accesslist.AccessListMember, error)
Expand All @@ -278,8 +279,11 @@ type AccessListSuggestionClient interface {
type RoleGetter interface {
GetRole(ctx context.Context, name string) (types.Role, error)
}

type AccessListGetter interface {
GetAccessList(ctx context.Context, name string) (*accesslist.AccessList, error)
GetAccessLists(ctx context.Context) ([]*accesslist.AccessList, error)
GetAccessListMember(ctx context.Context, accessList string, memberName string) (*accesslist.AccessListMember, error)
}

// Modules defines interface that external libraries can implement customizing
Expand Down
121 changes: 68 additions & 53 deletions lib/services/access_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,17 @@ func (ImplicitAccessListError) Error() string {
return "requested AccessList does not have explicit member list"
}

// AccessListGetter defines an interface that can retrieve an access list.
type AccessListGetter interface {
// GetAccessList returns the specified access list resource.
GetAccessList(context.Context, string) (*accesslist.AccessList, error)
// GetAccessList returns the specified access list resource.
GetAccessLists(context.Context) ([]*accesslist.AccessList, error)
}

// AccessListMemberGetter defines an interface that can retrieve access list members.
type AccessListMemberGetter interface {
// GetAccessListMember returns the specified access list member resource.
// May return a DynamicAccessListError if the requested access list has an
// implicit member list and the underlying implementation does not have
// enough information to compute the dynamic member record.
GetAccessListMember(ctx context.Context, accessList string, memberName string) (*accesslist.AccessListMember, error)
// GetAccessList returns the specified access list resource.
GetAccessList(context.Context, string) (*accesslist.AccessList, error)
// GetAccessList returns the specified access list resource.
GetAccessLists(context.Context) ([]*accesslist.AccessList, error)
}

// AccessListMembersGetter defines an interface for reading access list members.
Expand Down Expand Up @@ -227,7 +223,7 @@ func UnmarshalAccessListMember(data []byte, opts ...MarshalOption) (*accesslist.
}

// IsAccessListOwner will return true if the user is an owner for the current list.
func IsAccessListOwner(identity tlsca.Identity, accessList *accesslist.AccessList) error {
func IsAccessListOwner(ctx context.Context, members AccessListMemberGetter, identity tlsca.Identity, accessList *accesslist.AccessList) error {
// An opaque access denied error.
accessDenied := trace.AccessDenied("access denied")

Expand All @@ -236,6 +232,10 @@ func IsAccessListOwner(identity tlsca.Identity, accessList *accesslist.AccessLis
return owner.Name == identity.Username
})
if ownerIdx == -1 {
err := recursiveIsAccessListOwnerCheck(ctx, members, identity, accessList)
if err == nil {
return nil
}
return accessDenied
}

Expand All @@ -251,19 +251,17 @@ func IsAccessListOwner(identity tlsca.Identity, accessList *accesslist.AccessLis
// AccessListMembershipChecker will check if users are members of an access list and
// makes sure the user is not locked and meets membership requirements.
type AccessListMembershipChecker struct {
members AccessListMemberGetter
accessList AccessListGetter
locks LockGetter
clock clockwork.Clock
members AccessListMemberGetter
locks LockGetter
clock clockwork.Clock
}

// NewAccessListMembershipChecker will create a new access list membership checker.
func NewAccessListMembershipChecker(clock clockwork.Clock, members AccessListMemberGetter, accessLists AccessListGetter, locks LockGetter) *AccessListMembershipChecker {
func NewAccessListMembershipChecker(clock clockwork.Clock, members AccessListMemberGetter, locks LockGetter) *AccessListMembershipChecker {
return &AccessListMembershipChecker{
accessList: accessLists,
members: members,
locks: locks,
clock: clock,
members: members,
locks: locks,
clock: clock,
}
}

Expand Down Expand Up @@ -303,11 +301,11 @@ func recurseAccessLists(username string, initialList string, lists map[string][]
}
}
}
return trace.NotFound("user %s is not a member of the access list or its parents")
return trace.NotFound("user %s is not a member of the access list or its parents", username)
}

func (a AccessListMembershipChecker) recursiveIsAccessListMemberCheck(ctx context.Context, identity tlsca.Identity, accessList *accesslist.AccessList) error {
acls, err := a.accessList.GetAccessLists(ctx)
acls, err := a.members.GetAccessLists(ctx)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -333,7 +331,7 @@ func (a AccessListMembershipChecker) recursiveIsAccessListMemberCheck(ctx contex
return trace.AccessDenied("user %s's membership has expired in the access list", identity.Username)
}

subAccessList, err := a.accessList.GetAccessList(ctx, current)
subAccessList, err := a.members.GetAccessList(ctx, current)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -345,6 +343,45 @@ func (a AccessListMembershipChecker) recursiveIsAccessListMemberCheck(ctx contex
return trace.Wrap(err)
}

func recursiveIsAccessListOwnerCheck(ctx context.Context, members AccessListMemberGetter, identity tlsca.Identity, accessList *accesslist.AccessList) error {
acls, err := members.GetAccessLists(ctx)
if err != nil {
return trace.Wrap(err)
}
lists := make(map[string][]string)
for _, acl := range acls {
lists[acl.GetName()] = acl.Spec.DynamicOwners.AccessLists
}
err = recurseAccessLists(identity.Username, accessList.GetName(), lists,
func(username, list string) error {
_, err := members.GetAccessListMember(ctx, list, username)
if err != nil {
return trace.Wrap(err)
}
return nil
},
func(lists map[string][]string, current string) error {
member, err := members.GetAccessListMember(ctx, current, identity.Username)
if err != nil {
return trace.Wrap(err)
}
expires := member.Spec.Expires
if !expires.IsZero() && !time.Now().Before(expires) {
return trace.AccessDenied("user %s's membership has expired in the access list", identity.Username)
}

subAccessList, err := members.GetAccessList(ctx, current)
if err != nil {
return trace.Wrap(err)
}
if !UserMeetsRequirements(identity, subAccessList.Spec.OwnershipRequires) {
return trace.AccessDenied("user %s is a member, but does not have the roles or traits required to be a member of this list", identity.Username)
}
return nil
})
return trace.Wrap(err)
}

// IsAccessListMember will return true if the user is a member for the current list.
func (a AccessListMembershipChecker) IsAccessListMember(ctx context.Context, identity tlsca.Identity, accessList *accesslist.AccessList) error {
username := identity.Username
Expand All @@ -362,49 +399,27 @@ func (a AccessListMembershipChecker) IsAccessListMember(ctx context.Context, ide
return trace.AccessDenied("user %s is currently locked", username)
}
}

member, err := a.members.GetAccessListMember(ctx, accessList.GetName(), username)
if err != nil && !trace.IsNotFound(err) {
return trace.Wrap(err)
}
// try find if the user could be a member of any lists by recursing
err := a.recursiveIsAccessListMemberCheck(ctx, identity, accessList)
if trace.IsNotFound(err) {
// try find if the user could be a member of any lists by recursing
err := a.recursiveIsAccessListMemberCheck(ctx, identity, accessList)
if trace.IsNotFound(err) {
// The member has not been found, so we know they're not a member of this list.
return trace.NotFound("user %s is not a member of the access list", username)
}
if err != nil {
// Some other error has occurred
return trace.Wrap(err)
}
return nil
} else if err != nil {
// The member has not been found, so we know they're not a member of this list.
return trace.NotFound("user %s is not a member of the access list, or its nested lists", username)
}
if err != nil {
// Some other error has occurred
return trace.Wrap(err)
}

expires := member.Spec.Expires
if !expires.IsZero() && !a.clock.Now().Before(expires) {
return trace.AccessDenied("user %s's membership has expired in the access list", username)
}

if !UserMeetsRequirements(identity, accessList.Spec.MembershipRequires) {
return trace.AccessDenied("user %s is a member, but does not have the roles or traits required to be a member of this list", username)
}

return nil
}

// TODO(mdwn): Remove this in favor of using the access list membership checker.
func IsAccessListMember(ctx context.Context, identity tlsca.Identity, clock clockwork.Clock, accessList *accesslist.AccessList, accessListGetter AccessListGetter, members AccessListMemberGetter) error {
func IsAccessListMember(ctx context.Context, identity tlsca.Identity, clock clockwork.Clock, accessList *accesslist.AccessList, members AccessListMemberGetter) error {
// See if the member getter also implements lock getter. If so, use it. Otherwise, nil is fine.
lockGetter, _ := members.(LockGetter)
return AccessListMembershipChecker{
accessList: accessListGetter,
members: members,
locks: lockGetter,
clock: clock,
members: members,
locks: lockGetter,
clock: clock,
}.IsAccessListMember(ctx, identity, accessList)
}

Expand Down
10 changes: 8 additions & 2 deletions lib/services/access_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,11 @@ func TestIsAccessListOwner(t *testing.T) {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()

accessList := newAccessList(t)

test.errAssertionFunc(t, IsAccessListOwner(test.identity, accessList))
test.errAssertionFunc(t, IsAccessListOwner(ctx, &testMembersAndLockGetter{}, test.identity, accessList))
})
}
}
Expand All @@ -276,6 +277,11 @@ func (t *testMembersAndLockGetter) GetAccessList(context.Context, string) (*acce
return nil, trace.NotImplemented("not implemented")
}

// GetAccessList implements AccessListGetter.
func (t *testMembersAndLockGetter) GetAccessLists(context.Context) ([]*accesslist.AccessList, error) {
return nil, trace.NotImplemented("not implemented")
}

// ListAccessListMembers returns a paginated list of all access list members.
func (t *testMembersAndLockGetter) ListAccessListMembers(ctx context.Context, accessList string, _ int, _ string) (members []*accesslist.AccessListMember, nextToken string, err error) {
for _, member := range t.members[accessList] {
Expand Down Expand Up @@ -494,7 +500,7 @@ func TestIsAccessListMemberChecker(t *testing.T) {
}
getter := &testMembersAndLockGetter{members: memberMap, locks: test.locks}

checker := NewAccessListMembershipChecker(clockwork.NewFakeClockAt(test.currentTime), getter, getter, getter)
checker := NewAccessListMembershipChecker(clockwork.NewFakeClockAt(test.currentTime), getter, getter)
test.errAssertionFunc(t, checker.IsAccessListMember(ctx, test.identity, accessList))
})
}
Expand Down

0 comments on commit e62c03c

Please sign in to comment.