diff --git a/cmd/configure.go b/cmd/configure.go index b1b1819..32c5e87 100644 --- a/cmd/configure.go +++ b/cmd/configure.go @@ -58,7 +58,7 @@ var configureCmd = &cobra.Command{ func configureCmdImpl(cmd *cobra.Command, args []string) { if err := configure(); err != nil { - exit(err) + exit(fmt.Errorf("failed to configure cobra CLI: %w", err)) } } diff --git a/cmd/install_windows.go b/cmd/install_windows.go index e833b42..25dac29 100644 --- a/cmd/install_windows.go +++ b/cmd/install_windows.go @@ -60,9 +60,9 @@ func installCmdImpl(cmd *cobra.Command, args []string) { ) if err := configureService(); err != nil { - exit(err) + exit(fmt.Errorf("failed to configure service: %w", err)) } else if err := installService(constants.DisplayName, config, recoveryActions); err != nil { - exit(err) + exit(fmt.Errorf("failed to install service: %w", err)) } } diff --git a/cmd/list-app-owners.go b/cmd/list-app-owners.go index 59a511d..7b713d6 100644 --- a/cmd/list-app-owners.go +++ b/cmd/list-app-owners.go @@ -19,7 +19,6 @@ package cmd import ( "context" - "fmt" "os" "os/signal" "sync" @@ -48,56 +47,37 @@ func listAppOwnersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure app owners...") - start := time.Now() - stream := listAppOwners(ctx, azClient, listApps(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure app owners...") + start := time.Now() + stream := listAppOwners(ctx, azClient, listApps(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } -func listAppOwners(ctx context.Context, client client.AzureClient, apps <-chan interface{}) <-chan interface{} { +func listAppOwners(ctx context.Context, client client.AzureClient, apps <-chan azureWrapper[models.App]) <-chan azureWrapper[models.AppOwners] { var ( - out = make(chan interface{}) - ids = make(chan string) - streams = pipeline.Demux(ctx.Done(), ids, 25) + out = make(chan azureWrapper[models.AppOwners]) + streams = pipeline.Demux(ctx.Done(), apps, 25) wg sync.WaitGroup ) - go func() { - defer close(ids) - - for result := range pipeline.OrDone(ctx.Done(), apps) { - if app, ok := result.(AzureWrapper).Data.(models.App); !ok { - log.Error(fmt.Errorf("failed type assertion"), "unable to continue enumerating app owners", "result", result) - return - } else { - ids <- app.Id - } - } - }() - wg.Add(len(streams)) for i := range streams { stream := streams[i] go func() { defer wg.Done() - for id := range stream { + for app := range stream { var ( data = models.AppOwners{ - AppId: id, + AppId: app.Data.AppId, } count = 0 ) - for item := range client.ListAzureADAppOwners(ctx, id, "", "", "", nil) { + for item := range client.ListAzureADAppOwners(ctx, app.Data.Id, "", "", "", nil) { if item.Error != nil { - log.Error(item.Error, "unable to continue processing owners for this app", "appId", id) + log.Error(item.Error, "unable to continue processing owners for this app", "appId", app.Data.AppId) } else { appOwner := models.AppOwner{ Owner: item.Ok, @@ -109,11 +89,11 @@ func listAppOwners(ctx context.Context, client client.AzureClient, apps <-chan i } } - out <- AzureWrapper{ - Kind: enums.KindAZAppOwner, - Data: data, - } - log.V(1).Info("finished listing app owners", "appId", id, "count", count) + out <- NewAzureWrapper( + enums.KindAZAppOwner, + data, + ) + log.V(1).Info("finished listing app owners", "appId", app.Data.AppId, "count", count) } }() } diff --git a/cmd/list-app-owners_test.go b/cmd/list-app-owners_test.go index bef3a27..4fce7ff 100644 --- a/cmd/list-app-owners_test.go +++ b/cmd/list-app-owners_test.go @@ -24,6 +24,7 @@ import ( "testing" "github.com/bloodhoundad/azurehound/client/mocks" + "github.com/bloodhoundad/azurehound/enums" "github.com/bloodhoundad/azurehound/models" "github.com/bloodhoundad/azurehound/models/azure" "github.com/golang/mock/gomock" @@ -40,7 +41,7 @@ func TestListAppOwners(t *testing.T) { mockClient := mocks.NewMockAzureClient(ctrl) - mockAppsChannel := make(chan interface{}) + mockAppsChannel := make(chan azureWrapper[models.App]) mockAppOwnerChannel := make(chan azure.AppOwnerResult) mockAppOwnerChannel2 := make(chan azure.AppOwnerResult) @@ -53,12 +54,8 @@ func TestListAppOwners(t *testing.T) { go func() { defer close(mockAppsChannel) - mockAppsChannel <- AzureWrapper{ - Data: models.App{}, - } - mockAppsChannel <- AzureWrapper{ - Data: models.App{}, - } + mockAppsChannel <- NewAzureWrapper(enums.KindAZApp, models.App{}) + mockAppsChannel <- NewAzureWrapper(enums.KindAZApp, models.App{}) }() go func() { defer close(mockAppOwnerChannel) @@ -81,21 +78,13 @@ func TestListAppOwners(t *testing.T) { if result, ok := <-channel; !ok { t.Fatalf("failed to receive from channel") - } else if wrapper, ok := result.(AzureWrapper); !ok { - t.Errorf("failed type assertion: got %T, want %T", result, AzureWrapper{}) - } else if data, ok := wrapper.Data.(models.AppOwners); !ok { - t.Errorf("failed type assertion: got %T, want %T", wrapper.Data, models.AppOwners{}) - } else if len(data.Owners) != 2 { - t.Errorf("got %v, want %v", len(data.Owners), 2) + } else if len(result.Data.Owners) != 2 { + t.Errorf("got %v, want %v", len(result.Data.Owners), 2) } if result, ok := <-channel; !ok { t.Fatalf("failed to receive from channel") - } else if wrapper, ok := result.(AzureWrapper); !ok { - t.Errorf("failed type assertion: got %T, want %T", result, AzureWrapper{}) - } else if data, ok := wrapper.Data.(models.AppOwners); !ok { - t.Errorf("failed type assertion: got %T, want %T", wrapper.Data, models.AppOwners{}) - } else if len(data.Owners) != 1 { - t.Errorf("got %v, want %v", len(data.Owners), 2) + } else if len(result.Data.Owners) != 1 { + t.Errorf("got %v, want %v", len(result.Data.Owners), 2) } } diff --git a/cmd/list-app-role-assignments.go b/cmd/list-app-role-assignments.go index 37e18dd..2a845dd 100644 --- a/cmd/list-app-role-assignments.go +++ b/cmd/list-app-role-assignments.go @@ -48,19 +48,14 @@ func listAppRoleAssignmentsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure active directory app role assignments...") - start := time.Now() - servicePrincipals := listServicePrincipals(ctx, azClient) - stream := listAppRoleAssignments(ctx, azClient, servicePrincipals) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure active directory app role assignments...") + start := time.Now() + servicePrincipals := listServicePrincipals(ctx, azClient) + stream := listAppRoleAssignments(ctx, azClient, servicePrincipals) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listAppRoleAssignments(ctx context.Context, client client.AzureClient, servicePrincipals <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-apps.go b/cmd/list-apps.go index a0bd005..6e392e5 100644 --- a/cmd/list-apps.go +++ b/cmd/list-apps.go @@ -45,22 +45,17 @@ func listAppsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure active directory applications...") - start := time.Now() - stream := listApps(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure active directory applications...") + start := time.Now() + stream := listApps(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } -func listApps(ctx context.Context, client client.AzureClient) <-chan interface{} { - out := make(chan interface{}) +func listApps(ctx context.Context, client client.AzureClient) <-chan azureWrapper[models.App] { + out := make(chan azureWrapper[models.App]) go func() { defer close(out) @@ -72,14 +67,14 @@ func listApps(ctx context.Context, client client.AzureClient) <-chan interface{} } else { log.V(2).Info("found application", "app", item) count++ - out <- AzureWrapper{ - Kind: enums.KindAZApp, - Data: models.App{ + out <- NewAzureWrapper( + enums.KindAZApp, + models.App{ Application: item.Ok, TenantId: client.TenantInfo().TenantId, TenantName: client.TenantInfo().DisplayName, }, - } + ) } } log.Info("finished listing all apps", "count", count) diff --git a/cmd/list-apps_test.go b/cmd/list-apps_test.go index e734b73..19708fd 100644 --- a/cmd/list-apps_test.go +++ b/cmd/list-apps_test.go @@ -57,11 +57,7 @@ func TestListApps(t *testing.T) { }() channel := listApps(ctx, mockClient) - result := <-channel - if _, ok := result.(AzureWrapper); !ok { - t.Errorf("failed type assertion: got %T, want %T", result, AzureWrapper{}) - } - + <-channel if _, ok := <-channel; ok { t.Error("expected channel to close from an error result but it did not") } diff --git a/cmd/list-automation-account-role-assignments.go b/cmd/list-automation-account-role-assignments.go index 6cab0b9..b81dfe9 100644 --- a/cmd/list-automation-account-role-assignments.go +++ b/cmd/list-automation-account-role-assignments.go @@ -49,19 +49,14 @@ func listAutomationAccountRoleAssignmentImpl(cmd *cobra.Command, args []string) defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure automation account role assignments...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - stream := listAutomationAccountRoleAssignments(ctx, azClient, listAutomationAccounts(ctx, azClient, subscriptions)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure automation account role assignments...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + stream := listAutomationAccountRoleAssignments(ctx, azClient, listAutomationAccounts(ctx, azClient, subscriptions)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listAutomationAccountRoleAssignments(ctx context.Context, client client.AzureClient, automationAccounts <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-automation-accounts.go b/cmd/list-automation-accounts.go index 676db9c..be269e4 100644 --- a/cmd/list-automation-accounts.go +++ b/cmd/list-automation-accounts.go @@ -48,18 +48,13 @@ func listAutomationAccountsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure automation accounts...") - start := time.Now() - stream := listAutomationAccounts(ctx, azClient, listSubscriptions(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure automation accounts...") + start := time.Now() + stream := listAutomationAccounts(ctx, azClient, listSubscriptions(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listAutomationAccounts(ctx context.Context, client client.AzureClient, subscriptions <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-azure-ad.go b/cmd/list-azure-ad.go index 42c3f6a..2501037 100644 --- a/cmd/list-azure-ad.go +++ b/cmd/list-azure-ad.go @@ -50,25 +50,17 @@ func listAzureADCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure ad objects...") - start := time.Now() - stream := listAllAD(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure ad objects...") + start := time.Now() + stream := listAllAD(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listAllAD(ctx context.Context, client client.AzureClient) <-chan interface{} { var ( - apps = make(chan interface{}) - apps2 = make(chan interface{}) - devices = make(chan interface{}) devices2 = make(chan interface{}) @@ -87,8 +79,9 @@ func listAllAD(ctx context.Context, client client.AzureClient) <-chan interface{ ) // Enumerate Apps, AppOwners and AppMembers - pipeline.Tee(ctx.Done(), listApps(ctx, client), apps, apps2) - appOwners := listAppOwners(ctx, client, apps2) + appChans := pipeline.TeeFixed(ctx.Done(), listApps(ctx, client), 2) + apps := pipeline.ToAny(ctx.Done(), appChans[0]) + appOwners := pipeline.ToAny(ctx.Done(), listAppOwners(ctx, client, appChans[1])) // Enumerate Devices and DeviceOwners pipeline.Tee(ctx.Done(), listDevices(ctx, client), devices, devices2) diff --git a/cmd/list-azure-rm.go b/cmd/list-azure-rm.go index 0c1117d..33ba125 100644 --- a/cmd/list-azure-rm.go +++ b/cmd/list-azure-rm.go @@ -52,18 +52,13 @@ func listAzureRMCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure resource management objects...") - start := time.Now() - stream := listAllRM(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure resource management objects...") + start := time.Now() + stream := listAllRM(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listAllRM(ctx context.Context, client client.AzureClient) <-chan interface{} { diff --git a/cmd/list-device-owners.go b/cmd/list-device-owners.go index a60ccde..5edfbea 100644 --- a/cmd/list-device-owners.go +++ b/cmd/list-device-owners.go @@ -48,18 +48,13 @@ func listDeviceOwnersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure device owners...") - start := time.Now() - stream := listDeviceOwners(ctx, azClient, listDevices(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure device owners...") + start := time.Now() + stream := listDeviceOwners(ctx, azClient, listDevices(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listDeviceOwners(ctx context.Context, client client.AzureClient, devices <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-devices.go b/cmd/list-devices.go index e2585e1..3b24c64 100644 --- a/cmd/list-devices.go +++ b/cmd/list-devices.go @@ -45,18 +45,13 @@ func listDevicesCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure active directory devices...") - start := time.Now() - stream := listDevices(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure active directory devices...") + start := time.Now() + stream := listDevices(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listDevices(ctx context.Context, client client.AzureClient) <-chan interface{} { diff --git a/cmd/list-function-app-role-assignments.go b/cmd/list-function-app-role-assignments.go index de428dc..defd85a 100644 --- a/cmd/list-function-app-role-assignments.go +++ b/cmd/list-function-app-role-assignments.go @@ -49,19 +49,14 @@ func listFunctionAppRoleAssignmentImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure function app role assignments...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - stream := listFunctionAppRoleAssignments(ctx, azClient, listFunctionApps(ctx, azClient, subscriptions)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure function app role assignments...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + stream := listFunctionAppRoleAssignments(ctx, azClient, listFunctionApps(ctx, azClient, subscriptions)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listFunctionAppRoleAssignments(ctx context.Context, client client.AzureClient, functionApps <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-function-apps.go b/cmd/list-function-apps.go index dba07cf..0c273bb 100644 --- a/cmd/list-function-apps.go +++ b/cmd/list-function-apps.go @@ -48,18 +48,13 @@ func listFunctionAppsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure function apps...") - start := time.Now() - stream := listFunctionApps(ctx, azClient, listSubscriptions(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure function apps...") + start := time.Now() + stream := listFunctionApps(ctx, azClient, listSubscriptions(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listFunctionApps(ctx context.Context, client client.AzureClient, subscriptions <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-group-members.go b/cmd/list-group-members.go index 51e9478..5699707 100644 --- a/cmd/list-group-members.go +++ b/cmd/list-group-members.go @@ -48,18 +48,13 @@ func listGroupMembersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure group members...") - start := time.Now() - stream := listGroupMembers(ctx, azClient, listGroups(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure group members...") + start := time.Now() + stream := listGroupMembers(ctx, azClient, listGroups(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listGroupMembers(ctx context.Context, client client.AzureClient, groups <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-group-owners.go b/cmd/list-group-owners.go index 7cfb051..8eae384 100644 --- a/cmd/list-group-owners.go +++ b/cmd/list-group-owners.go @@ -48,18 +48,13 @@ func listGroupOwnersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure group owners...") - start := time.Now() - stream := listGroupOwners(ctx, azClient, listGroups(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure group owners...") + start := time.Now() + stream := listGroupOwners(ctx, azClient, listGroups(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listGroupOwners(ctx context.Context, client client.AzureClient, groups <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-groups.go b/cmd/list-groups.go index 72a059b..fa9eef8 100644 --- a/cmd/list-groups.go +++ b/cmd/list-groups.go @@ -45,18 +45,13 @@ func listGroupsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure active directory groups...") - start := time.Now() - stream := listGroups(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure active directory groups...") + start := time.Now() + stream := listGroups(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listGroups(ctx context.Context, client client.AzureClient) <-chan interface{} { diff --git a/cmd/list-key-vault-access-policies.go b/cmd/list-key-vault-access-policies.go index aae832b..ac4db77 100644 --- a/cmd/list-key-vault-access-policies.go +++ b/cmd/list-key-vault-access-policies.go @@ -50,25 +50,20 @@ func listKeyVaultAccessPoliciesCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) + azClient := connectAndCreateClient() + log.Info("collecting azure key vault access policies...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + if filters, ok := config.KeyVaultAccessTypes.Value().([]enums.KeyVaultAccessType); !ok { + exit(fmt.Errorf("filter failed type assertion")) } else { - log.Info("collecting azure key vault access policies...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - if filters, ok := config.KeyVaultAccessTypes.Value().([]enums.KeyVaultAccessType); !ok { - exit(fmt.Errorf("filter failed type assertion")) - } else { - if len(filters) > 0 { - log.Info("applying access type filters", "filters", filters) - } - stream := listKeyVaultAccessPolicies(ctx, azClient, listKeyVaults(ctx, azClient, subscriptions), filters) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) + if len(filters) > 0 { + log.Info("applying access type filters", "filters", filters) } + stream := listKeyVaultAccessPolicies(ctx, azClient, listKeyVaults(ctx, azClient, subscriptions), filters) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } } diff --git a/cmd/list-key-vault-access-policies_test.go b/cmd/list-key-vault-access-policies_test.go index a52504f..2a50626 100644 --- a/cmd/list-key-vault-access-policies_test.go +++ b/cmd/list-key-vault-access-policies_test.go @@ -50,7 +50,7 @@ func TestListKeyVaultAccessPolicies(t *testing.T) { KeyVault: azure.KeyVault{ Properties: azure.VaultProperties{ AccessPolicies: []azure.AccessPolicyEntry{ - azure.AccessPolicyEntry{ + { Permissions: azure.KeyVaultPermissions{ Certificates: []string{"Get"}, }, diff --git a/cmd/list-key-vault-contributors.go b/cmd/list-key-vault-contributors.go index 1e017e7..cbd8406 100644 --- a/cmd/list-key-vault-contributors.go +++ b/cmd/list-key-vault-contributors.go @@ -47,21 +47,16 @@ func listKeyVaultContributorsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure key vault contributors...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - keyVaults := listKeyVaults(ctx, azClient, subscriptions) - kvRoleAssignments := listKeyVaultRoleAssignments(ctx, azClient, keyVaults) - stream := listKeyVaultContributors(ctx, kvRoleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure key vault contributors...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + keyVaults := listKeyVaults(ctx, azClient, subscriptions) + kvRoleAssignments := listKeyVaultRoleAssignments(ctx, azClient, keyVaults) + stream := listKeyVaultContributors(ctx, kvRoleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listKeyVaultContributors( @@ -73,8 +68,8 @@ func listKeyVaultContributors( contributors := internal.Map(filteredAssignments, func(ra models.KeyVaultRoleAssignment) models.KeyVaultContributor { return models.KeyVaultContributor{ - ra.RoleAssignment, - ra.KeyVaultId, + Contributor: ra.RoleAssignment, + KeyVaultId: ra.KeyVaultId, } }) diff --git a/cmd/list-key-vault-kvcontributors.go b/cmd/list-key-vault-kvcontributors.go index a9bd4f0..295f253 100644 --- a/cmd/list-key-vault-kvcontributors.go +++ b/cmd/list-key-vault-kvcontributors.go @@ -47,21 +47,16 @@ func listKeyVaultKVContributorsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure key vault kvcontributors...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - keyVaults := listKeyVaults(ctx, azClient, subscriptions) - kvRoleAssignments := listKeyVaultRoleAssignments(ctx, azClient, keyVaults) - stream := listKeyVaultKVContributors(ctx, kvRoleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure key vault kvcontributors...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + keyVaults := listKeyVaults(ctx, azClient, subscriptions) + kvRoleAssignments := listKeyVaultRoleAssignments(ctx, azClient, keyVaults) + stream := listKeyVaultKVContributors(ctx, kvRoleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listKeyVaultKVContributors( diff --git a/cmd/list-key-vault-owners.go b/cmd/list-key-vault-owners.go index a119c98..a04bbc5 100644 --- a/cmd/list-key-vault-owners.go +++ b/cmd/list-key-vault-owners.go @@ -47,21 +47,16 @@ func listKeyVaultOwnersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure key vault owners...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - keyVaults := listKeyVaults(ctx, azClient, subscriptions) - kvRoleAssignments := listKeyVaultRoleAssignments(ctx, azClient, keyVaults) - stream := listKeyVaultOwners(ctx, kvRoleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure key vault owners...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + keyVaults := listKeyVaults(ctx, azClient, subscriptions) + kvRoleAssignments := listKeyVaultRoleAssignments(ctx, azClient, keyVaults) + stream := listKeyVaultOwners(ctx, kvRoleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listKeyVaultOwners( diff --git a/cmd/list-key-vault-role-assignments.go b/cmd/list-key-vault-role-assignments.go index ef8f7da..3ff4fdc 100644 --- a/cmd/list-key-vault-role-assignments.go +++ b/cmd/list-key-vault-role-assignments.go @@ -48,19 +48,14 @@ func listKeyVaultRoleAssignmentsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure key vault role assignments...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - stream := listKeyVaultRoleAssignments(ctx, azClient, listKeyVaults(ctx, azClient, subscriptions)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure key vault role assignments...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + stream := listKeyVaultRoleAssignments(ctx, azClient, listKeyVaults(ctx, azClient, subscriptions)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listKeyVaultRoleAssignments(ctx context.Context, client client.AzureClient, keyVaults <-chan interface{}) <-chan azureWrapper[models.KeyVaultRoleAssignments] { diff --git a/cmd/list-key-vault-user-access-admins.go b/cmd/list-key-vault-user-access-admins.go index 2162c87..fb29016 100644 --- a/cmd/list-key-vault-user-access-admins.go +++ b/cmd/list-key-vault-user-access-admins.go @@ -47,21 +47,16 @@ func listKeyVaultUserAccessAdminsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure key vault user access admins...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - keyVaults := listKeyVaults(ctx, azClient, subscriptions) - kvRoleAssignments := listKeyVaultRoleAssignments(ctx, azClient, keyVaults) - stream := listKeyVaultUserAccessAdmins(ctx, kvRoleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure key vault user access admins...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + keyVaults := listKeyVaults(ctx, azClient, subscriptions) + kvRoleAssignments := listKeyVaultRoleAssignments(ctx, azClient, keyVaults) + stream := listKeyVaultUserAccessAdmins(ctx, kvRoleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listKeyVaultUserAccessAdmins( diff --git a/cmd/list-key-vaults.go b/cmd/list-key-vaults.go index fcee5fb..4619450 100644 --- a/cmd/list-key-vaults.go +++ b/cmd/list-key-vaults.go @@ -48,18 +48,13 @@ func listKeyVaultsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure key vaults...") - start := time.Now() - stream := listKeyVaults(ctx, azClient, listSubscriptions(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure key vaults...") + start := time.Now() + stream := listKeyVaults(ctx, azClient, listSubscriptions(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listKeyVaults(ctx context.Context, client client.AzureClient, subscriptions <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-management-group-descendants.go b/cmd/list-management-group-descendants.go index e04208e..76e2ed9 100644 --- a/cmd/list-management-group-descendants.go +++ b/cmd/list-management-group-descendants.go @@ -48,18 +48,13 @@ func listManagementGroupDescendantsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure management group descendants...") - start := time.Now() - stream := listManagementGroupDescendants(ctx, azClient, listManagementGroups(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure management group descendants...") + start := time.Now() + stream := listManagementGroupDescendants(ctx, azClient, listManagementGroups(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listManagementGroupDescendants(ctx context.Context, client client.AzureClient, managementGroups <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-management-group-owners.go b/cmd/list-management-group-owners.go index cac1c08..28cc020 100644 --- a/cmd/list-management-group-owners.go +++ b/cmd/list-management-group-owners.go @@ -47,20 +47,15 @@ func listManagementGroupOwnersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure management group owners...") - start := time.Now() - managementGroups := listManagementGroups(ctx, azClient) - roleAssignments := listManagementGroupRoleAssignments(ctx, azClient, managementGroups) - stream := listManagementGroupOwners(ctx, roleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure management group owners...") + start := time.Now() + managementGroups := listManagementGroups(ctx, azClient) + roleAssignments := listManagementGroupRoleAssignments(ctx, azClient, managementGroups) + stream := listManagementGroupOwners(ctx, roleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listManagementGroupOwners( diff --git a/cmd/list-management-group-role-assignments.go b/cmd/list-management-group-role-assignments.go index 0e051e0..a9f631a 100644 --- a/cmd/list-management-group-role-assignments.go +++ b/cmd/list-management-group-role-assignments.go @@ -48,19 +48,14 @@ func listManagementGroupRoleAssignmentsCmdImpl(cmd *cobra.Command, args []string defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure management group role assignments...") - start := time.Now() - managementGroups := listManagementGroups(ctx, azClient) - stream := listManagementGroupRoleAssignments(ctx, azClient, managementGroups) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure management group role assignments...") + start := time.Now() + managementGroups := listManagementGroups(ctx, azClient) + stream := listManagementGroupRoleAssignments(ctx, azClient, managementGroups) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listManagementGroupRoleAssignments(ctx context.Context, client client.AzureClient, managementGroups <-chan interface{}) <-chan azureWrapper[models.ManagementGroupRoleAssignments] { diff --git a/cmd/list-management-group-user-access-admins.go b/cmd/list-management-group-user-access-admins.go index 8f5a9b6..58598a6 100644 --- a/cmd/list-management-group-user-access-admins.go +++ b/cmd/list-management-group-user-access-admins.go @@ -47,20 +47,15 @@ func listManagementGroupUserAccessAdminsCmdImpl(cmd *cobra.Command, args []strin defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure management group user access admins...") - start := time.Now() - managementGroups := listManagementGroups(ctx, azClient) - roleAssignments := listManagementGroupRoleAssignments(ctx, azClient, managementGroups) - stream := listManagementGroupUserAccessAdmins(ctx, roleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure management group user access admins...") + start := time.Now() + managementGroups := listManagementGroups(ctx, azClient) + roleAssignments := listManagementGroupRoleAssignments(ctx, azClient, managementGroups) + stream := listManagementGroupUserAccessAdmins(ctx, roleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listManagementGroupUserAccessAdmins( diff --git a/cmd/list-management-groups.go b/cmd/list-management-groups.go index 3adc09f..f887041 100644 --- a/cmd/list-management-groups.go +++ b/cmd/list-management-groups.go @@ -46,18 +46,13 @@ func listManagementGroupsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure active directory management groups...") - start := time.Now() - stream := listManagementGroups(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure active directory management groups...") + start := time.Now() + stream := listManagementGroups(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listManagementGroups(ctx context.Context, client client.AzureClient) <-chan interface{} { diff --git a/cmd/list-resource-group-owners.go b/cmd/list-resource-group-owners.go index e2e189a..3f03b8e 100644 --- a/cmd/list-resource-group-owners.go +++ b/cmd/list-resource-group-owners.go @@ -47,21 +47,16 @@ func listResourceGroupOwnersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure resource group owners...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - resourceGroups := listResourceGroups(ctx, azClient, subscriptions) - roleAssignments := listResourceGroupRoleAssignments(ctx, azClient, resourceGroups) - stream := listResourceGroupOwners(ctx, roleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure resource group owners...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + resourceGroups := listResourceGroups(ctx, azClient, subscriptions) + roleAssignments := listResourceGroupRoleAssignments(ctx, azClient, resourceGroups) + stream := listResourceGroupOwners(ctx, roleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listResourceGroupOwners( diff --git a/cmd/list-resource-group-role-assignments.go b/cmd/list-resource-group-role-assignments.go index f6a0596..7c35b9d 100644 --- a/cmd/list-resource-group-role-assignments.go +++ b/cmd/list-resource-group-role-assignments.go @@ -48,20 +48,15 @@ func listResourceGroupRoleAssignmentsCmdImpl(cmd *cobra.Command, args []string) defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure resource group role assignments...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - resourceGroups := listResourceGroups(ctx, azClient, subscriptions) - stream := listResourceGroupRoleAssignments(ctx, azClient, resourceGroups) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure resource group role assignments...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + resourceGroups := listResourceGroups(ctx, azClient, subscriptions) + stream := listResourceGroupRoleAssignments(ctx, azClient, resourceGroups) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listResourceGroupRoleAssignments(ctx context.Context, client client.AzureClient, resourceGroups <-chan interface{}) <-chan azureWrapper[models.ResourceGroupRoleAssignments] { diff --git a/cmd/list-resource-group-user-access-admins.go b/cmd/list-resource-group-user-access-admins.go index 7b5b470..004e7bf 100644 --- a/cmd/list-resource-group-user-access-admins.go +++ b/cmd/list-resource-group-user-access-admins.go @@ -47,21 +47,16 @@ func listResourceGroupUserAccessAdminsCmdImpl(cmd *cobra.Command, args []string) defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure resource group user access admins...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - resourceGroups := listResourceGroups(ctx, azClient, subscriptions) - roleAssignments := listResourceGroupRoleAssignments(ctx, azClient, resourceGroups) - stream := listResourceGroupUserAccessAdmins(ctx, roleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure resource group user access admins...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + resourceGroups := listResourceGroups(ctx, azClient, subscriptions) + roleAssignments := listResourceGroupRoleAssignments(ctx, azClient, resourceGroups) + stream := listResourceGroupUserAccessAdmins(ctx, roleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listResourceGroupUserAccessAdmins( diff --git a/cmd/list-resource-groups.go b/cmd/list-resource-groups.go index 9b16192..1811e06 100644 --- a/cmd/list-resource-groups.go +++ b/cmd/list-resource-groups.go @@ -48,18 +48,13 @@ func listResourceGroupsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure resource groups...") - start := time.Now() - stream := listResourceGroups(ctx, azClient, listSubscriptions(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure resource groups...") + start := time.Now() + stream := listResourceGroups(ctx, azClient, listSubscriptions(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listResourceGroups(ctx context.Context, client client.AzureClient, subscriptions <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-role-assignments.go b/cmd/list-role-assignments.go index 2a4ab98..edb6e67 100644 --- a/cmd/list-role-assignments.go +++ b/cmd/list-role-assignments.go @@ -48,19 +48,14 @@ func listRoleAssignmentsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure active directory role assignments...") - start := time.Now() - roles := listRoles(ctx, azClient) - stream := listRoleAssignments(ctx, azClient, roles) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure active directory role assignments...") + start := time.Now() + roles := listRoles(ctx, azClient) + stream := listRoleAssignments(ctx, azClient, roles) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listRoleAssignments(ctx context.Context, client client.AzureClient, roles <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-roles.go b/cmd/list-roles.go index 9736704..7359219 100644 --- a/cmd/list-roles.go +++ b/cmd/list-roles.go @@ -45,18 +45,13 @@ func listRolesCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure active directory roles...") - start := time.Now() - stream := listRoles(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure active directory roles...") + start := time.Now() + stream := listRoles(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listRoles(ctx context.Context, client client.AzureClient) <-chan interface{} { diff --git a/cmd/list-root.go b/cmd/list-root.go index 73ea4ff..c3ea462 100644 --- a/cmd/list-root.go +++ b/cmd/list-root.go @@ -52,18 +52,13 @@ func listCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure objects...") - start := time.Now() - stream := listAll(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure objects...") + start := time.Now() + stream := listAll(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listAll(ctx context.Context, client client.AzureClient) <-chan interface{} { diff --git a/cmd/list-service-principal-owners.go b/cmd/list-service-principal-owners.go index e856bdf..2f920d7 100644 --- a/cmd/list-service-principal-owners.go +++ b/cmd/list-service-principal-owners.go @@ -48,18 +48,13 @@ func listServicePrincipalOwnersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure service principal owners...") - start := time.Now() - stream := listServicePrincipalOwners(ctx, azClient, listServicePrincipals(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure service principal owners...") + start := time.Now() + stream := listServicePrincipalOwners(ctx, azClient, listServicePrincipals(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listServicePrincipalOwners(ctx context.Context, client client.AzureClient, servicePrincipals <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-service-principal-owners_test.go b/cmd/list-service-principal-owners_test.go index 8758532..2479f8b 100644 --- a/cmd/list-service-principal-owners_test.go +++ b/cmd/list-service-principal-owners_test.go @@ -96,6 +96,6 @@ func TestListServicePrincipalOwners(t *testing.T) { } else if data, ok := wrapper.Data.(models.ServicePrincipalOwners); !ok { t.Errorf("failed type assertion: got %T, want %T", wrapper.Data, models.ServicePrincipalOwners{}) } else if len(data.Owners) != 1 { - t.Errorf("got %v, want %v", len(data.Owners), 2) + t.Errorf("got %v, want %v", len(data.Owners), 1) } } diff --git a/cmd/list-service-principals.go b/cmd/list-service-principals.go index 06f37aa..9eecc76 100644 --- a/cmd/list-service-principals.go +++ b/cmd/list-service-principals.go @@ -45,18 +45,13 @@ func listServicePrincipalsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure active directory service principals...") - start := time.Now() - stream := listServicePrincipals(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure active directory service principals...") + start := time.Now() + stream := listServicePrincipals(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listServicePrincipals(ctx context.Context, client client.AzureClient) <-chan interface{} { diff --git a/cmd/list-storage-account-role-assignments.go b/cmd/list-storage-account-role-assignments.go index 5a77994..501388f 100644 --- a/cmd/list-storage-account-role-assignments.go +++ b/cmd/list-storage-account-role-assignments.go @@ -49,19 +49,14 @@ func listStorageAccountRoleAssignmentsImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure storage account role assignments...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - stream := listStorageAccountRoleAssignments(ctx, azClient, listStorageAccounts(ctx, azClient, subscriptions)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure storage account role assignments...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + stream := listStorageAccountRoleAssignments(ctx, azClient, listStorageAccounts(ctx, azClient, subscriptions)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listStorageAccountRoleAssignments(ctx context.Context, client client.AzureClient, storageAccounts <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-storage-accounts.go b/cmd/list-storage-accounts.go index 8a51698..a2497b6 100644 --- a/cmd/list-storage-accounts.go +++ b/cmd/list-storage-accounts.go @@ -48,18 +48,13 @@ func listStorageAccountsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure storage accounts...") - start := time.Now() - stream := listStorageAccounts(ctx, azClient, listSubscriptions(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure storage accounts...") + start := time.Now() + stream := listStorageAccounts(ctx, azClient, listSubscriptions(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listStorageAccounts(ctx context.Context, client client.AzureClient, subscriptions <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-storage-containers.go b/cmd/list-storage-containers.go index ba25f55..759ec0b 100644 --- a/cmd/list-storage-containers.go +++ b/cmd/list-storage-containers.go @@ -48,26 +48,21 @@ func listStorageContainersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure storage containers...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - storageAccounts := listStorageAccounts(ctx, azClient, subscriptions) - stream := listStorageContainers(ctx, azClient, storageAccounts) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure storage containers...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + storageAccounts := listStorageAccounts(ctx, azClient, subscriptions) + stream := listStorageContainers(ctx, azClient, storageAccounts) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listStorageContainers(ctx context.Context, client client.AzureClient, storageAccounts <-chan interface{}) <-chan interface{} { var ( - out = make(chan interface{}) - ids = make(chan interface{}) + out = make(chan interface{}) + ids = make(chan interface{}) // The original size of the demuxxer cascaded into error messages for a lot of collection steps. // Decreasing the demuxxer size only here is sufficient to prevent the cascade // The error message with higher values for size is diff --git a/cmd/list-subscription-owners.go b/cmd/list-subscription-owners.go index d9ecdc3..01b8834 100644 --- a/cmd/list-subscription-owners.go +++ b/cmd/list-subscription-owners.go @@ -49,20 +49,15 @@ func listSubscriptionOwnersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure subscription owners...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - roleAssignments := listSubscriptionRoleAssignments(ctx, azClient, subscriptions) - stream := listSubscriptionOwners(ctx, azClient, roleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure subscription owners...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + roleAssignments := listSubscriptionRoleAssignments(ctx, azClient, subscriptions) + stream := listSubscriptionOwners(ctx, azClient, roleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listSubscriptionOwners(ctx context.Context, client client.AzureClient, roleAssignments <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-subscription-role-assignments.go b/cmd/list-subscription-role-assignments.go index 6a7cef2..6b6b5f9 100644 --- a/cmd/list-subscription-role-assignments.go +++ b/cmd/list-subscription-role-assignments.go @@ -48,19 +48,14 @@ func listSubscriptionRoleAssignmentsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure subscription role assignments...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - stream := listSubscriptionRoleAssignments(ctx, azClient, subscriptions) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure subscription role assignments...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + stream := listSubscriptionRoleAssignments(ctx, azClient, subscriptions) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listSubscriptionRoleAssignments(ctx context.Context, client client.AzureClient, subscriptions <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-subscription-user-access-admins.go b/cmd/list-subscription-user-access-admins.go index ba4886a..8ff0817 100644 --- a/cmd/list-subscription-user-access-admins.go +++ b/cmd/list-subscription-user-access-admins.go @@ -49,20 +49,15 @@ func listSubscriptionUserAccessAdminsCmdImpl(cmd *cobra.Command, args []string) defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure subscription user access admins...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - roleAssignments := listSubscriptionRoleAssignments(ctx, azClient, subscriptions) - stream := listSubscriptionUserAccessAdmins(ctx, azClient, roleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure subscription user access admins...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + roleAssignments := listSubscriptionRoleAssignments(ctx, azClient, subscriptions) + stream := listSubscriptionUserAccessAdmins(ctx, azClient, roleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listSubscriptionUserAccessAdmins(ctx context.Context, client client.AzureClient, vmRoleAssignments <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-subscriptions.go b/cmd/list-subscriptions.go index 63f939c..e2172a2 100644 --- a/cmd/list-subscriptions.go +++ b/cmd/list-subscriptions.go @@ -49,18 +49,13 @@ func listSubscriptionsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure active directory subscriptions...") - start := time.Now() - stream := listSubscriptions(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure active directory subscriptions...") + start := time.Now() + stream := listSubscriptions(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listSubscriptions(ctx context.Context, client client.AzureClient) <-chan interface{} { diff --git a/cmd/list-tenants.go b/cmd/list-tenants.go index 19d4097..751800b 100644 --- a/cmd/list-tenants.go +++ b/cmd/list-tenants.go @@ -45,18 +45,13 @@ func listTenantsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure active directory tenants...") - start := time.Now() - stream := listTenants(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure active directory tenants...") + start := time.Now() + stream := listTenants(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listTenants(ctx context.Context, client client.AzureClient) <-chan interface{} { diff --git a/cmd/list-users.go b/cmd/list-users.go index aede4d6..6b224cd 100644 --- a/cmd/list-users.go +++ b/cmd/list-users.go @@ -45,18 +45,13 @@ func listUsersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure active directory users...") - start := time.Now() - stream := listUsers(ctx, azClient) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure active directory users...") + start := time.Now() + stream := listUsers(ctx, azClient) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listUsers(ctx context.Context, client client.AzureClient) <-chan interface{} { diff --git a/cmd/list-virtual-machine-admin-logins.go b/cmd/list-virtual-machine-admin-logins.go index 22cfcef..7724c86 100644 --- a/cmd/list-virtual-machine-admin-logins.go +++ b/cmd/list-virtual-machine-admin-logins.go @@ -47,21 +47,16 @@ func listVirtualMachineAdminLoginsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure virtual machine admin logins...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - vms := listVirtualMachines(ctx, azClient, subscriptions) - vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) - stream := listVirtualMachineAdminLogins(ctx, vmRoleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure virtual machine admin logins...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + vms := listVirtualMachines(ctx, azClient, subscriptions) + vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) + stream := listVirtualMachineAdminLogins(ctx, vmRoleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listVirtualMachineAdminLogins( diff --git a/cmd/list-virtual-machine-avere-contributors.go b/cmd/list-virtual-machine-avere-contributors.go index 4f78975..a6e98a0 100644 --- a/cmd/list-virtual-machine-avere-contributors.go +++ b/cmd/list-virtual-machine-avere-contributors.go @@ -47,21 +47,16 @@ func listVirtualMachineAvereContributorsCmdImpl(cmd *cobra.Command, args []strin defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure virtual machine averecontributors...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - vms := listVirtualMachines(ctx, azClient, subscriptions) - vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) - stream := listVirtualMachineAvereContributors(ctx, vmRoleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure virtual machine averecontributors...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + vms := listVirtualMachines(ctx, azClient, subscriptions) + vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) + stream := listVirtualMachineAvereContributors(ctx, vmRoleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listVirtualMachineAvereContributors( diff --git a/cmd/list-virtual-machine-contributors.go b/cmd/list-virtual-machine-contributors.go index bd64132..e06196b 100644 --- a/cmd/list-virtual-machine-contributors.go +++ b/cmd/list-virtual-machine-contributors.go @@ -47,21 +47,16 @@ func listVirtualMachineContributorsCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure virtual machine contributors...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - vms := listVirtualMachines(ctx, azClient, subscriptions) - vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) - stream := listVirtualMachineContributors(ctx, vmRoleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure virtual machine contributors...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + vms := listVirtualMachines(ctx, azClient, subscriptions) + vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) + stream := listVirtualMachineContributors(ctx, vmRoleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listVirtualMachineContributors( diff --git a/cmd/list-virtual-machine-owners.go b/cmd/list-virtual-machine-owners.go index 0edd7ff..2399703 100644 --- a/cmd/list-virtual-machine-owners.go +++ b/cmd/list-virtual-machine-owners.go @@ -47,21 +47,16 @@ func listVirtualMachineOwnersCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure virtual machine owners...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - vms := listVirtualMachines(ctx, azClient, subscriptions) - vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) - stream := listVirtualMachineOwners(ctx, vmRoleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure virtual machine owners...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + vms := listVirtualMachines(ctx, azClient, subscriptions) + vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) + stream := listVirtualMachineOwners(ctx, vmRoleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listVirtualMachineOwners( diff --git a/cmd/list-virtual-machine-role-assignments.go b/cmd/list-virtual-machine-role-assignments.go index 9344d55..7fc996b 100644 --- a/cmd/list-virtual-machine-role-assignments.go +++ b/cmd/list-virtual-machine-role-assignments.go @@ -48,19 +48,14 @@ func listVirtualMachineRoleAssignmentsCmdImpl(cmd *cobra.Command, args []string) defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure virtual machine role assignments...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - stream := listVirtualMachineRoleAssignments(ctx, azClient, listVirtualMachines(ctx, azClient, subscriptions)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure virtual machine role assignments...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + stream := listVirtualMachineRoleAssignments(ctx, azClient, listVirtualMachines(ctx, azClient, subscriptions)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listVirtualMachineRoleAssignments(ctx context.Context, client client.AzureClient, virtualMachines <-chan interface{}) <-chan azureWrapper[models.VirtualMachineRoleAssignments] { diff --git a/cmd/list-virtual-machine-user-access-admins.go b/cmd/list-virtual-machine-user-access-admins.go index 2bab9bb..ab4eac4 100644 --- a/cmd/list-virtual-machine-user-access-admins.go +++ b/cmd/list-virtual-machine-user-access-admins.go @@ -47,21 +47,16 @@ func listVirtualMachineUserAccessAdminsCmdImpl(cmd *cobra.Command, args []string defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure virtual machine user access admins...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - vms := listVirtualMachines(ctx, azClient, subscriptions) - vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) - stream := listVirtualMachineUserAccessAdmins(ctx, vmRoleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure virtual machine user access admins...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + vms := listVirtualMachines(ctx, azClient, subscriptions) + vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) + stream := listVirtualMachineUserAccessAdmins(ctx, vmRoleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listVirtualMachineUserAccessAdmins( diff --git a/cmd/list-virtual-machine-vmcontributors.go b/cmd/list-virtual-machine-vmcontributors.go index b800a6a..5a88e81 100644 --- a/cmd/list-virtual-machine-vmcontributors.go +++ b/cmd/list-virtual-machine-vmcontributors.go @@ -47,21 +47,16 @@ func listVirtualMachineVMContributorsCmdImpl(cmd *cobra.Command, args []string) defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure virtual machine vmcontributors...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - vms := listVirtualMachines(ctx, azClient, subscriptions) - vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) - stream := listVirtualMachineVMContributors(ctx, vmRoleAssignments) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure virtual machine vmcontributors...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + vms := listVirtualMachines(ctx, azClient, subscriptions) + vmRoleAssignments := listVirtualMachineRoleAssignments(ctx, azClient, vms) + stream := listVirtualMachineVMContributors(ctx, vmRoleAssignments) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listVirtualMachineVMContributors( diff --git a/cmd/list-virtual-machines.go b/cmd/list-virtual-machines.go index 9483cfd..ff2da30 100644 --- a/cmd/list-virtual-machines.go +++ b/cmd/list-virtual-machines.go @@ -48,18 +48,13 @@ func listVirtualMachinesCmdImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure virtual machines...") - start := time.Now() - stream := listVirtualMachines(ctx, azClient, listSubscriptions(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure virtual machines...") + start := time.Now() + stream := listVirtualMachines(ctx, azClient, listSubscriptions(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listVirtualMachines(ctx context.Context, client client.AzureClient, subscriptions <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-workflow-role-assignments.go b/cmd/list-workflow-role-assignments.go index b5b48ff..7d4709b 100644 --- a/cmd/list-workflow-role-assignments.go +++ b/cmd/list-workflow-role-assignments.go @@ -49,19 +49,14 @@ func listWorkflowRoleAssignmentImpl(cmd *cobra.Command, args []string) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure workflow role assignments...") - start := time.Now() - subscriptions := listSubscriptions(ctx, azClient) - stream := listWorkflowRoleAsignments(ctx, azClient, listWorkflows(ctx, azClient, subscriptions)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure workflow role assignments...") + start := time.Now() + subscriptions := listSubscriptions(ctx, azClient) + stream := listWorkflowRoleAsignments(ctx, azClient, listWorkflows(ctx, azClient, subscriptions)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listWorkflowRoleAsignments(ctx context.Context, client client.AzureClient, workflows <-chan interface{}) <-chan interface{} { diff --git a/cmd/list-workflows.go b/cmd/list-workflows.go index b1cbd42..6207f5e 100644 --- a/cmd/list-workflows.go +++ b/cmd/list-workflows.go @@ -47,19 +47,13 @@ func listWorkflowsCmdImpl(cmd *cobra.Command, args []string) { ctx, stop := signal.NotifyContext(cmd.Context(), os.Interrupt, os.Kill) defer gracefulShutdown(stop) - log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) - } else { - log.Info("collecting azure workflows...") - start := time.Now() - stream := listWorkflows(ctx, azClient, listSubscriptions(ctx, azClient)) - outputStream(ctx, stream) - duration := time.Since(start) - log.Info("collection completed", "duration", duration.String()) - } + azClient := connectAndCreateClient() + log.Info("collecting azure workflows...") + start := time.Now() + stream := listWorkflows(ctx, azClient, listSubscriptions(ctx, azClient)) + outputStream(ctx, stream) + duration := time.Since(start) + log.Info("collection completed", "duration", duration.String()) } func listWorkflows(ctx context.Context, client client.AzureClient, subscriptions <-chan interface{}) <-chan interface{} { diff --git a/cmd/start.go b/cmd/start.go index d8f5e6d..8a30b8b 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -73,20 +73,15 @@ func start(ctx context.Context) { defer gracefulShutdown(stop) log.V(1).Info("testing connections") - if err := testConnections(); err != nil { - exit(err) - } else if azClient, err := newAzureClient(); err != nil { - exit(err) + if azClient := connectAndCreateClient(); azClient == nil { + exit(fmt.Errorf("azClient is unexpectedly nil")) } else if bheInstance, err := url.Parse(config.BHEUrl.Value().(string)); err != nil { - exit(err) + exit(fmt.Errorf("unable to parse BHE url: %w", err)) } else if bheClient, err := newSigningHttpClient(BHEAuthSignature, config.BHETokenId.Value().(string), config.BHEToken.Value().(string), config.Proxy.Value().(string)); err != nil { - exit(err) + exit(fmt.Errorf("failed to create new signing HTTP client: %w", err)) + } else if err := updateClient(ctx, *bheInstance, bheClient); err != nil { + exit(fmt.Errorf("failed to update client: %w", err)) } else { - - if err := updateClient(ctx, *bheInstance, bheClient); err != nil { - exit(err) - } - log.Info("connected successfully! waiting for tasks...") ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() @@ -97,51 +92,64 @@ func start(ctx context.Context) { select { case <-ticker.C: if currentTask != nil { - log.V(1).Info("currently performing collection; continuing...") + log.V(1).Info("collection in progress...") + if err := checkin(ctx, *bheInstance, bheClient); err != nil { + log.Error(err, "bloodhound enterprise service checkin failed") + } } else { - log.V(2).Info("checking for available collection tasks") - if availableTasks, err := getAvailableTasks(ctx, *bheInstance, bheClient); err != nil { - log.Error(err, "unable to fetch available tasks for azurehound") - } else { - - // Get only the tasks that have reached their execution time - executableTasks := []models.ClientTask{} - now := time.Now() - for _, task := range availableTasks { - if task.ExectionTime.Before(now) || task.ExectionTime.Equal(now) { - executableTasks = append(executableTasks, task) + go func() { + log.V(2).Info("checking for available collection tasks") + if availableTasks, err := getAvailableTasks(ctx, *bheInstance, bheClient); err != nil { + log.Error(err, "unable to fetch available tasks for azurehound") + } else { + + // Get only the tasks that have reached their execution time + executableTasks := []models.ClientTask{} + now := time.Now() + for _, task := range availableTasks { + if task.ExectionTime.Before(now) || task.ExectionTime.Equal(now) { + executableTasks = append(executableTasks, task) + } } - } - // Sort tasks in ascending order by execution time - sort.Slice(executableTasks, func(i, j int) bool { - return executableTasks[i].ExectionTime.Before(executableTasks[j].ExectionTime) - }) + // Sort tasks in ascending order by execution time + sort.Slice(executableTasks, func(i, j int) bool { + return executableTasks[i].ExectionTime.Before(executableTasks[j].ExectionTime) + }) - if len(executableTasks) == 0 { - log.V(2).Info("there are no tasks for azurehound to complete at this time") - } else { + if len(executableTasks) == 0 { + log.V(2).Info("there are no tasks for azurehound to complete at this time") + } else { - // Notify BHE instance of task start - currentTask = &executableTasks[0] - startTask(ctx, *bheInstance, bheClient, currentTask.Id) - start := time.Now() + // Notify BHE instance of task start + currentTask = &executableTasks[0] + if err := startTask(ctx, *bheInstance, bheClient, currentTask.Id); err != nil { + log.Error(err, "failed to start task, will retry on next heartbeat") + currentTask = nil + return + } + + start := time.Now() + + // Batch data out for ingestion + stream := listAll(ctx, azClient) + batches := pipeline.Batch(ctx.Done(), stream, 999, 10*time.Second) + if err := ingest(ctx, *bheInstance, bheClient, batches); err != nil { + log.Error(err, "ingestion failed") + } - // Batch data out for ingestion - stream := listAll(ctx, azClient) - batches := pipeline.Batch(ctx.Done(), stream, 999, 10*time.Second) - if err := ingest(ctx, *bheInstance, bheClient, batches); err != nil { - log.Error(err, "ingestion failed; collection will be re-attempted") - } else { // Notify BHE instance of task end duration := time.Since(start) - endTask(ctx, *bheInstance, bheClient) - log.Info("finished collection task", "id", currentTask.Id, "duration", duration.String()) + if err := endTask(ctx, *bheInstance, bheClient); err != nil { + log.Error(err, "failed to end task") + } else { + log.Info("finished collection task", "id", currentTask.Id, "duration", duration.String()) + } currentTask = nil } } - } + }() } case <-ctx.Done(): return @@ -193,6 +201,20 @@ func getAvailableTasks(ctx context.Context, bheUrl url.URL, bheClient *http.Clie } } +func checkin(ctx context.Context, bheUrl url.URL, bheClient *http.Client) error { + endpoint := bheUrl.ResolveReference(&url.URL{Path: "/api/v2/jobs/current"}) + + if req, err := rest.NewRequest(ctx, "GET", endpoint, nil, nil, nil); err != nil { + return err + } else if res, err := bheClient.Do(req); err != nil { + return err + } else if !contains([]int{http.StatusOK, http.StatusNotFound}, res.StatusCode) { + return fmt.Errorf("unexpected response code %s", res.Status) + } else { + return nil + } +} + func startTask(ctx context.Context, bheUrl url.URL, bheClient *http.Client, taskId int) error { log.Info("beginning collection task", "id", taskId) var ( diff --git a/cmd/uninstall_windows.go b/cmd/uninstall_windows.go index 4fedcc9..8d162be 100644 --- a/cmd/uninstall_windows.go +++ b/cmd/uninstall_windows.go @@ -18,6 +18,8 @@ package cmd import ( + "fmt" + "github.com/bloodhoundad/azurehound/constants" "github.com/spf13/cobra" "golang.org/x/sys/windows/svc/eventlog" @@ -38,7 +40,7 @@ var uninstallCmd = &cobra.Command{ func uninstallCmdImpl(cmd *cobra.Command, args []string) { if err := uninstallService(constants.Name); err != nil { - exit(err) + exit(fmt.Errorf("failed to uninstall service: %w", err)) } } diff --git a/cmd/utils.go b/cmd/utils.go index c1bcd92..446aea3 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -321,7 +321,7 @@ func (s signingTransport) RoundTrip(req *http.Request) (*http.Response, error) { return s.base.RoundTrip(clone) } -func contains(collection []string, value string) bool { +func contains[T comparable](collection []T, value T) bool { for _, item := range collection { if item == value { return true @@ -395,7 +395,7 @@ func outputStream[T any](ctx context.Context, stream <-chan T) { formatted := pipeline.FormatJson(ctx.Done(), stream) if path := config.OutputFile.Value().(string); path != "" { if err := sinks.WriteToFile(ctx, path, formatted); err != nil { - exit(err) + exit(fmt.Errorf("failed to write stream to file: %w", err)) } } else { sinks.WriteToConsole(ctx, formatted) @@ -425,3 +425,16 @@ func mgmtGroupRoleAssignmentFilter(roleId string) func(models.ManagementGroupRol return path.Base(ra.RoleAssignment.Properties.RoleDefinitionId) == roleId } } + +func connectAndCreateClient() client.AzureClient { + log.V(1).Info("testing connections") + if err := testConnections(); err != nil { + exit(fmt.Errorf("failed to test connections: %w", err)) + } else if azClient, err := newAzureClient(); err != nil { + exit(fmt.Errorf("failed to create new Azure client: %w", err)) + } else { + return azClient + } + + panic("unexpectedly failed to create azClient without error") +} diff --git a/models/azure/immutability_policy.go b/models/azure/immutability_policy.go index 1a4c6a4..e24c2f9 100644 --- a/models/azure/immutability_policy.go +++ b/models/azure/immutability_policy.go @@ -19,6 +19,6 @@ package azure type ImmutabilityPolicy struct { Etag string `json:"etag,omitempty"` - Properties ImmutabilityPolicyProperties `json:properties` + Properties ImmutabilityPolicyProperties `json:"properties"` UpdateHistory ImmutablePolicyUpdateHistory `json:"updateHistory,omitempty"` } diff --git a/models/azure/immutable_storage_with_versioning.go b/models/azure/immutable_storage_with_versioning.go index e04d8b0..6921c4c 100644 --- a/models/azure/immutable_storage_with_versioning.go +++ b/models/azure/immutable_storage_with_versioning.go @@ -22,5 +22,5 @@ import "github.com/bloodhoundad/azurehound/enums" type ImmutableStorageWithVersioning struct { Enabled bool `json:"enabled,omitempty"` MigrationState enums.MigrationState `json:"migrationState,omitempty"` - timeStamp string `json:"timeStamp,omitempty"` + TimeStamp string `json:"timeStamp,omitempty"` } diff --git a/models/azure/logic_app_definition.go b/models/azure/logic_app_definition.go index c9fd11e..2cff0bb 100644 --- a/models/azure/logic_app_definition.go +++ b/models/azure/logic_app_definition.go @@ -30,7 +30,7 @@ type Definition struct { } type Action struct { - Type string `json:type` + Type string `json:"type"` // Kind is missing in the MSDN, but returned and present in examples and during testing Kind string `json:"kind,omitempty"` Inputs map[string]interface{} `json:"inputs,omitempty"` diff --git a/models/azure/site_config.go b/models/azure/site_config.go index 7c7617f..658c0af 100644 --- a/models/azure/site_config.go +++ b/models/azure/site_config.go @@ -20,18 +20,18 @@ package azure import "github.com/bloodhoundad/azurehound/enums" type SiteConfig struct { - AcrUseManagedIdentityCreds bool `json:"acrUseManagedIdentityCreds,omitemtpy"` - AcrUserManagedIdentityID string `json:"acrUserManagedIdentityID,omitemtpy"` - AlwaysOn bool `json:"alwaysOn,omitemtpy"` - ApiDefinition ApiDefinitionInfo `json:"apiDefinition,omitemtpy"` - ApiManagementConfig ApiManagementConfig `json:"apiManagementConfig,omitemtpy"` - AppCommandLine string `json:"appCommandLine,omitemtpy"` - AppSettings []NameValuePair `json:"appSettings,omitemtpy"` - AutoHealEnabled bool `json:"autoHealEnabled,omitemtpy"` - AutoHealRules string `json:"autoHealRules,omitemtpy"` - AutoSwapSlotName string `json:"autoSwapSlotName,omitemtpy"` - AzureStorageAccounts map[string]AzureStorageInfoValue `json:"azureStorageAccounts,omitemtpy"` - ConnectionStrings []ConnStringInfo `json:"connectionStrings,omitemtpy"` + AcrUseManagedIdentityCreds bool `json:"acrUseManagedIdentityCreds,omitempty"` + AcrUserManagedIdentityID string `json:"acrUserManagedIdentityID,omitempty"` + AlwaysOn bool `json:"alwaysOn,omitempty"` + ApiDefinition ApiDefinitionInfo `json:"apiDefinition,omitempty"` + ApiManagementConfig ApiManagementConfig `json:"apiManagementConfig,omitempty"` + AppCommandLine string `json:"appCommandLine,omitempty"` + AppSettings []NameValuePair `json:"appSettings,omitempty"` + AutoHealEnabled bool `json:"autoHealEnabled,omitempty"` + AutoHealRules string `json:"autoHealRules,omitempty"` + AutoSwapSlotName string `json:"autoSwapSlotName,omitempty"` + AzureStorageAccounts map[string]AzureStorageInfoValue `json:"azureStorageAccounts,omitempty"` + ConnectionStrings []ConnStringInfo `json:"connectionStrings,omitempty"` Cors CorsSettings `json:"cors,omitempty"` DefaultDocuments []string `json:"defaultDocuments,omitempty"` DetailedErrorLoggingEnabled bool `json:"detailedErrorLoggingEnabled,omitempty"` @@ -89,26 +89,26 @@ type SiteConfig struct { XManagedServiceIdentityId int `json:"xManagedServiceIdentityId,omitempty"` //Following ones have been found in testing, but not present in the documentation - AntivirusScanEnabled bool `json:"antivirusScanEnabled,omitemtpy"` - AzureMonitorLogCategories interface{} `json:"azureMonitorLogCategories,omitemtpy"` - CustomAppPoolIdentityAdminState interface{} `json:"customAppPoolIdentityAdminState,omitemtpy"` - CustomAppPoolIdentityTenantState interface{} `json:"customAppPoolIdentityTenantState,omitemtpy"` - ElasticWebAppScaleLimit interface{} `json:"elasticWebAppScaleLimit,omitemtpy"` - FileChangeAuditEnabled bool `json:"fileChangeAuditEnabled,omitemtpy"` - Http20ProxyFlag interface{} `json:"http20ProxyFlag,omitemtpy"` - IpSecurityRestrictionsDefaultAction interface{} `json:"ipSecurityRestrictionsDefaultAction,omitemtpy"` - Metadata interface{} `json:"metadata,omitemtpy"` - MinTlsCipherSuite interface{} `json:"minTlsCipherSuite,omitemtpy"` - PublishingPassword interface{} `json:"publishingPassword,omitemtpy"` - RoutingRules interface{} `json:"routingRules,omitemtpy"` - RuntimeADUser interface{} `json:"runtimeADUser,omitemtpy"` - RuntimeADUserPassword interface{} `json:"runtimeADUserPassword,omitemtpy"` - ScmIpSecurityRestrictionsDefaultAction interface{} `json:"scmIpSecurityRestrictionsDefaultAction,omitemtpy"` - SitePort interface{} `json:"sitePort,omitemtpy"` - StorageType interface{} `json:"storageType,omitemtpy"` - SupportedTlsCipherSuites interface{} `json:"supportedTlsCipherSuites,omitemtpy"` - WinAuthAdminState interface{} `json:"winAuthAdminState,omitemtpy"` - WinAuthTenantState interface{} `json:"winAuthTenantState,omitemtpy"` + AntivirusScanEnabled bool `json:"antivirusScanEnabled,omitempty"` + AzureMonitorLogCategories interface{} `json:"azureMonitorLogCategories,omitempty"` + CustomAppPoolIdentityAdminState interface{} `json:"customAppPoolIdentityAdminState,omitempty"` + CustomAppPoolIdentityTenantState interface{} `json:"customAppPoolIdentityTenantState,omitempty"` + ElasticWebAppScaleLimit interface{} `json:"elasticWebAppScaleLimit,omitempty"` + FileChangeAuditEnabled bool `json:"fileChangeAuditEnabled,omitempty"` + Http20ProxyFlag interface{} `json:"http20ProxyFlag,omitempty"` + IpSecurityRestrictionsDefaultAction interface{} `json:"ipSecurityRestrictionsDefaultAction,omitempty"` + Metadata interface{} `json:"metadata,omitempty"` + MinTlsCipherSuite interface{} `json:"minTlsCipherSuite,omitempty"` + PublishingPassword interface{} `json:"publishingPassword,omitempty"` + RoutingRules interface{} `json:"routingRules,omitempty"` + RuntimeADUser interface{} `json:"runtimeADUser,omitempty"` + RuntimeADUserPassword interface{} `json:"runtimeADUserPassword,omitempty"` + ScmIpSecurityRestrictionsDefaultAction interface{} `json:"scmIpSecurityRestrictionsDefaultAction,omitempty"` + SitePort interface{} `json:"sitePort,omitempty"` + StorageType interface{} `json:"storageType,omitempty"` + SupportedTlsCipherSuites interface{} `json:"supportedTlsCipherSuites,omitempty"` + WinAuthAdminState interface{} `json:"winAuthAdminState,omitempty"` + WinAuthTenantState interface{} `json:"winAuthTenantState,omitempty"` } type ApiDefinitionInfo struct { diff --git a/pipeline/pipeline.go b/pipeline/pipeline.go index f3110b0..93b1f30 100644 --- a/pipeline/pipeline.go +++ b/pipeline/pipeline.go @@ -118,6 +118,12 @@ func Demux[D, T any](done <-chan D, in <-chan T, size int) []<-chan T { return internal.Map(outputs, func(out chan T) <-chan T { return out }) } +func ToAny[D, T any](done <-chan D, in <-chan T) <-chan any { + return Map(done, in, func(t T) any { + return any(t) + }) +} + func Map[D, T, U any](done <-chan D, in <-chan T, fn func(T) U) <-chan U { out := make(chan U) go func() { @@ -143,7 +149,7 @@ func Filter[D, T any](done <-chan D, in <-chan T, fn func(T) bool) <-chan T { } // Tee copies the stream of data from a single channel to zero or more channels -func Tee[D, T any](done <-chan D, in <-chan T, outputs ...chan<- T) { +func Tee[D, T any](done <-chan D, in <-chan T, outputs ...chan T) { go func() { // Need to close outputs when goroutine exits to ensure we avoid deadlock defer func() { @@ -163,6 +169,16 @@ func Tee[D, T any](done <-chan D, in <-chan T, outputs ...chan<- T) { }() } +func TeeFixed[D, T any](done <-chan D, in <-chan T, size int) []<-chan T { + out := internal.Map(make([]any, size), func(_ any) chan T { + return make(chan T) + }) + Tee(done, in, out...) + return internal.Map(out, func(c chan T) <-chan T { + return c + }) +} + func Batch[D, T any](done <-chan D, in <-chan T, maxItems int, maxTimeout time.Duration) <-chan []T { out := make(chan []T)