Skip to content

Commit

Permalink
Updating opts to not leak URL parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
mvbrock committed Dec 30, 2024
1 parent 326795b commit 0dc6820
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
37 changes: 21 additions & 16 deletions lib/msgraph/paginated.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,17 @@ import (
"github.com/gravitational/trace"
)

const expandParameter = "$expand"
const expandMemberOf = "memberOf"

type IterateOptions struct {
ExpandMembers bool
}

// iterateSimple implements pagination for "simple" object lists, where additional logic isn't needed
func iterateSimple[T any](c *Client, ctx context.Context, endpoint string, params *url.Values, f func(*T) bool) error {
func iterateSimple[T any](c *Client, ctx context.Context, endpoint string, opts *IterateOptions, f func(*T) bool) error {
var err error
itErr := c.iterate(ctx, endpoint, params, func(msg json.RawMessage) bool {
itErr := c.iterate(ctx, endpoint, opts, func(msg json.RawMessage) bool {
var page []T
if err = json.Unmarshal(msg, &page); err != nil {
return false
Expand All @@ -51,19 +58,17 @@ func iterateSimple[T any](c *Client, ctx context.Context, endpoint string, param
}

// iterate implements pagination for "list" endpoints.
func (c *Client) iterate(ctx context.Context, endpoint string, params *url.Values, f func(json.RawMessage) bool) error {
func (c *Client) iterate(ctx context.Context, endpoint string, opts *IterateOptions, f func(json.RawMessage) bool) error {
uri := *c.baseURL
uri.Path = path.Join(uri.Path, endpoint)
rawQuery := url.Values{
"$top": {
strconv.Itoa(c.pageSize),
},
}
if params != nil {
for key, values := range *params {
for _, value := range values {
rawQuery.Add(key, value)
}
if opts != nil {
if opts.ExpandMembers {
rawQuery.Set(expandParameter, expandMemberOf)
}
}
uri.RawQuery = rawQuery.Encode()
Expand Down Expand Up @@ -93,32 +98,32 @@ func (c *Client) iterate(ctx context.Context, endpoint string, params *url.Value
// `f` will be called for each object in the result set.
// if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop).
// Ref: [https://learn.microsoft.com/en-us/graph/api/application-list].
func (c *Client) IterateApplications(ctx context.Context, params *url.Values, f func(*Application) bool) error {
return iterateSimple(c, ctx, "applications", params, f)
func (c *Client) IterateApplications(ctx context.Context, opts *IterateOptions, f func(*Application) bool) error {
return iterateSimple(c, ctx, "applications", opts, f)
}

// IterateGroups lists all groups in the Entra ID directory using pagination.
// `f` will be called for each object in the result set.
// if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop).
// Ref: [https://learn.microsoft.com/en-us/graph/api/group-list].
func (c *Client) IterateGroups(ctx context.Context, params *url.Values, f func(*Group) bool) error {
return iterateSimple(c, ctx, "groups", params, f)
func (c *Client) IterateGroups(ctx context.Context, opts *IterateOptions, f func(*Group) bool) error {
return iterateSimple(c, ctx, "groups", opts, f)
}

// IterateUsers lists all users in the Entra ID directory using pagination.
// `f` will be called for each object in the result set.
// if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop).
// Ref: [https://learn.microsoft.com/en-us/graph/api/user-list].
func (c *Client) IterateUsers(ctx context.Context, params *url.Values, f func(*User) bool) error {
return iterateSimple(c, ctx, "users", params, f)
func (c *Client) IterateUsers(ctx context.Context, opts *IterateOptions, f func(*User) bool) error {
return iterateSimple(c, ctx, "users", opts, f)
}

// IterateServicePrincipals lists all service principals in the Entra ID directory using pagination.
// `f` will be called for each object in the result set.
// if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop).
// Ref: [https://learn.microsoft.com/en-us/graph/api/user-list].
func (c *Client) IterateServicePrincipals(ctx context.Context, params *url.Values, f func(principal *ServicePrincipal) bool) error {
return iterateSimple(c, ctx, "servicePrincipals", params, f)
func (c *Client) IterateServicePrincipals(ctx context.Context, opts *IterateOptions, f func(principal *ServicePrincipal) bool) error {
return iterateSimple(c, ctx, "servicePrincipals", opts, f)
}

// IterateGroupMembers lists all members for the given Entra ID group using pagination.
Expand Down
12 changes: 5 additions & 7 deletions lib/srv/discovery/fetchers/azure-sync/principals.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ package azure_sync

import (
"context"
"net/url"

"github.com/gravitational/trace"
"google.golang.org/protobuf/types/known/timestamppb"

Expand All @@ -40,29 +38,29 @@ type queryResult struct {

// fetchPrincipals fetches the Azure principals (users, groups, and service principals) using the Graph API
func fetchPrincipals(ctx context.Context, subscriptionID string, cli *msgraph.Client) ([]*accessgraphv1alpha.AzurePrincipal, error) { //nolint: unused // invoked in a dependent PR
var params = &url.Values{
"$expand": []string{"memberOf"},
var opts = &msgraph.IterateOptions{
ExpandMembers: true,
}

// Fetch the users, groups, and service principals as directory objects
var queryResults []queryResult
err := cli.IterateUsers(ctx, params, func(user *msgraph.User) bool {
err := cli.IterateUsers(ctx, opts, func(user *msgraph.User) bool {
res := queryResult{metadata: dirObjMetadata{objectType: "user"}, dirObj: user.DirectoryObject}
queryResults = append(queryResults, res)
return true
})
if err != nil {
return nil, trace.Wrap(err)
}
err = cli.IterateGroups(ctx, params, func(group *msgraph.Group) bool {
err = cli.IterateGroups(ctx, opts, func(group *msgraph.Group) bool {
res := queryResult{metadata: dirObjMetadata{objectType: "group"}, dirObj: group.DirectoryObject}
queryResults = append(queryResults, res)
return true
})
if err != nil {
return nil, trace.Wrap(err)
}
err = cli.IterateServicePrincipals(ctx, params, func(servicePrincipal *msgraph.ServicePrincipal) bool {
err = cli.IterateServicePrincipals(ctx, opts, func(servicePrincipal *msgraph.ServicePrincipal) bool {
res := queryResult{metadata: dirObjMetadata{objectType: "servicePrincipal"}, dirObj: servicePrincipal.DirectoryObject}
queryResults = append(queryResults, res)
return true
Expand Down

0 comments on commit 0dc6820

Please sign in to comment.