diff --git a/cmd/stratus/main.go b/cmd/stratus/main.go index e8cfc3d3..f22bc637 100644 --- a/cmd/stratus/main.go +++ b/cmd/stratus/main.go @@ -14,6 +14,7 @@ func init() { showCmd := buildShowCmd() warmupCmd := buildWarmupCmd() detonateCmd := buildDetonateCmd() + revertCmd := buildRevertCmd() statusCmd := buildStatusCmd() cleanupCmd := buildCleanupCmd() versionCmd := buildVersionCmd() @@ -22,6 +23,7 @@ func init() { rootCmd.AddCommand(showCmd) rootCmd.AddCommand(warmupCmd) rootCmd.AddCommand(detonateCmd) + rootCmd.AddCommand(revertCmd) rootCmd.AddCommand(statusCmd) rootCmd.AddCommand(cleanupCmd) rootCmd.AddCommand(versionCmd) diff --git a/cmd/stratus/revert_cmd.go b/cmd/stratus/revert_cmd.go new file mode 100644 index 00000000..58bb6ff6 --- /dev/null +++ b/cmd/stratus/revert_cmd.go @@ -0,0 +1,43 @@ +package main + +import ( + "errors" + "github.com/datadog/stratus-red-team/pkg/stratus" + "github.com/datadog/stratus-red-team/pkg/stratus/runner" + "github.com/spf13/cobra" + "log" +) + +var revertForce bool + +func buildRevertCmd() *cobra.Command { + detonateCmd := &cobra.Command{ + Use: "revert", + Short: "Revert the detonation of an attack technique", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + return errors.New("you must specify at least one attack technique") + } + _, err := resolveTechniques(args) + return err + }, + Run: func(cmd *cobra.Command, args []string) { + techniques, _ := resolveTechniques(args) + doRevertCmd(techniques, revertForce) + }, + } + detonateCmd.Flags().BoolVarP(&revertForce, "force", "f", false, "Force attempt to reverting even if the technique is not in the DETONATED state") + return detonateCmd +} + +func doRevertCmd(techniques []*stratus.AttackTechnique, force bool) { + for i := range techniques { + stratusRunner := runner.NewRunner(techniques[i], force) + err := stratusRunner.Revert() + if err != nil { + log.Fatal(err) + } + } + + doStatusCmd(techniques) +} diff --git a/internal/attacktechniques/aws/defense-evasion/disable-cloudtrail/main.go b/internal/attacktechniques/aws/defense-evasion/disable-cloudtrail/main.go index 18507229..739904fe 100644 --- a/internal/attacktechniques/aws/defense-evasion/disable-cloudtrail/main.go +++ b/internal/attacktechniques/aws/defense-evasion/disable-cloudtrail/main.go @@ -30,6 +30,7 @@ Detonation: Calls cloudtrail:StopLogging `, PrerequisitesTerraformCode: tf, Detonate: detonate, + Revert: revert, }) } @@ -49,3 +50,15 @@ func detonate(params map[string]string) error { return nil } + +func revert(params map[string]string) error { + cloudtrailClient := cloudtrail.NewFromConfig(providers.AWS().GetConnection()) + trailName := params["cloudtrail_trail_name"] + + log.Println("Restarting CloudTrail trail " + trailName) + _, err := cloudtrailClient.StartLogging(context.Background(), &cloudtrail.StartLoggingInput{ + Name: aws.String(trailName), + }) + + return err +} diff --git a/internal/attacktechniques/aws/defense-evasion/remove-vpc-flow-logs/main.go b/internal/attacktechniques/aws/defense-evasion/remove-vpc-flow-logs/main.go index 65b4e4d7..c83f4277 100644 --- a/internal/attacktechniques/aws/defense-evasion/remove-vpc-flow-logs/main.go +++ b/internal/attacktechniques/aws/defense-evasion/remove-vpc-flow-logs/main.go @@ -49,3 +49,6 @@ func detonate(params map[string]string) error { return nil } + +// The technique is non-revertible once it has been detonated, otherwise it would require re-creating the VPC +// flow log programmatically, which we don't want as it's implemented in the Terraform for the warm-up phase diff --git a/internal/attacktechniques/aws/exfiltration/ebs-snapshot-share/main.go b/internal/attacktechniques/aws/exfiltration/ebs-snapshot-share/main.go index 34c6237d..f60928d9 100644 --- a/internal/attacktechniques/aws/exfiltration/ebs-snapshot-share/main.go +++ b/internal/attacktechniques/aws/exfiltration/ebs-snapshot-share/main.go @@ -30,9 +30,12 @@ Detonation: Calls ModifySnapshotAttribute to share the snapshot. `, PrerequisitesTerraformCode: tf, Detonate: detonate, + Revert: revert, }) } +const ShareWithAccountId = "012345678912" + func detonate(params map[string]string) error { ec2Client := ec2.NewFromConfig(providers.AWS().GetConnection()) @@ -46,7 +49,22 @@ func detonate(params map[string]string) error { SnapshotId: aws.String(ourSnapshotId), Attribute: types.SnapshotAttributeNameCreateVolumePermission, CreateVolumePermission: &types.CreateVolumePermissionModifications{ - Add: []types.CreateVolumePermission{{UserId: aws.String("012345678912")}}, + Add: []types.CreateVolumePermission{{UserId: aws.String(ShareWithAccountId)}}, + }, + }) + return err +} + +func revert(params map[string]string) error { + ec2Client := ec2.NewFromConfig(providers.AWS().GetConnection()) + ourSnapshotId := params["snapshot_id"] + + log.Println("Unsharing the volume snapshot " + ourSnapshotId) + _, err := ec2Client.ModifySnapshotAttribute(context.Background(), &ec2.ModifySnapshotAttributeInput{ + SnapshotId: aws.String(ourSnapshotId), + Attribute: types.SnapshotAttributeNameCreateVolumePermission, + CreateVolumePermission: &types.CreateVolumePermissionModifications{ + Remove: []types.CreateVolumePermission{{UserId: aws.String(ShareWithAccountId)}}, }, }) return err diff --git a/internal/attacktechniques/aws/exfiltration/s3-bucket-backdoor-bucket-policy/main.go b/internal/attacktechniques/aws/exfiltration/s3-bucket-backdoor-bucket-policy/main.go index 506e6961..64f94e21 100644 --- a/internal/attacktechniques/aws/exfiltration/s3-bucket-backdoor-bucket-policy/main.go +++ b/internal/attacktechniques/aws/exfiltration/s3-bucket-backdoor-bucket-policy/main.go @@ -33,6 +33,7 @@ Detonation: Backdoors the S3 bucket policy. `, PrerequisitesTerraformCode: tf, Detonate: detonate, + Revert: revert, }) } @@ -49,3 +50,15 @@ func detonate(params map[string]string) error { return err } + +func revert(params map[string]string) error { + s3Client := s3.NewFromConfig(providers.AWS().GetConnection()) + bucketName := params["bucket_name"] + + log.Println("Removing malicious bucket policy on " + bucketName) + _, err := s3Client.DeleteBucketPolicy(context.Background(), &s3.DeleteBucketPolicyInput{ + Bucket: aws.String(bucketName), + }) + + return err +} diff --git a/internal/attacktechniques/aws/persistence/iam-user-backdoor-existing/main.go b/internal/attacktechniques/aws/persistence/iam-user-backdoor-existing/main.go index 5dfc99fe..ceb9ee48 100644 --- a/internal/attacktechniques/aws/persistence/iam-user-backdoor-existing/main.go +++ b/internal/attacktechniques/aws/persistence/iam-user-backdoor-existing/main.go @@ -41,17 +41,27 @@ Detonation: Create the access key. log.Println("Successfully created access key " + *result.AccessKey.AccessKeyId) return nil }, - Cleanup: func() error { - // TODO: https://github.com/DataDog/stratus-red-team/issues/12 - /*iamClient := iam.NewFromConfig(providers.AWS().GetConnection()) - log.Println("Removing access key from IAM user") - result, err := iamClient.ListAccessKeys(context.Background(), &iam.ListAccessKeysInput{UserName: aws.String("sample-legit-user")}) + Revert: func(params map[string]string) error { + iamClient := iam.NewFromConfig(providers.AWS().GetConnection()) + userName := params["user_name"] + log.Println("Removing access key from IAM user " + userName) + result, err := iamClient.ListAccessKeys(context.Background(), &iam.ListAccessKeysInput{ + UserName: aws.String(userName), + }) if err != nil { return err } for i := range result.AccessKeyMetadata { - iamClient.DeleteAccessKey(context.Background(), &iam.DeleteAccessKeyInput{AccessKeyId: result.AccessKeyMetadata[i].AccessKeyId}) - }*/ + accessKeyId := result.AccessKeyMetadata[i].AccessKeyId + log.Println("Removing access key " + *accessKeyId) + _, err := iamClient.DeleteAccessKey(context.Background(), &iam.DeleteAccessKeyInput{ + AccessKeyId: accessKeyId, + UserName: aws.String(userName), + }) + if err != nil { + log.Println("failed: " + err.Error()) + } + } return nil }, diff --git a/internal/attacktechniques/aws/persistence/iam-user-create-new/main.go b/internal/attacktechniques/aws/persistence/iam-user-create-new/main.go index 2e92bd98..236bc0cd 100644 --- a/internal/attacktechniques/aws/persistence/iam-user-create-new/main.go +++ b/internal/attacktechniques/aws/persistence/iam-user-create-new/main.go @@ -62,15 +62,20 @@ Detonation: Creates the IAM user and attached 'AdministratorAccess' to it. return nil }, - Cleanup: func() error { + Revert: func(params map[string]string) error { iamClient := iam.NewFromConfig(providers.AWS().GetConnection()) - result, err := iamClient.ListAccessKeys(context.Background(), &iam.ListAccessKeysInput{UserName: userName}) + result, err := iamClient.ListAccessKeys(context.Background(), &iam.ListAccessKeysInput{ + UserName: userName, + }) if err != nil { return errors.New("unable to clean up IAM user access keys: " + err.Error()) } for i := range result.AccessKeyMetadata { accessKeyId := result.AccessKeyMetadata[i].AccessKeyId - _, err := iamClient.DeleteAccessKey(context.Background(), &iam.DeleteAccessKeyInput{UserName: userName, AccessKeyId: accessKeyId}) + _, err := iamClient.DeleteAccessKey(context.Background(), &iam.DeleteAccessKeyInput{ + UserName: userName, + AccessKeyId: accessKeyId, + }) if err != nil { return errors.New("unable to remove IAM user access key " + *accessKeyId + ": " + err.Error()) } diff --git a/pkg/stratus/attack_technique.go b/pkg/stratus/attack_technique.go index 72510e5b..7f495f39 100644 --- a/pkg/stratus/attack_technique.go +++ b/pkg/stratus/attack_technique.go @@ -10,8 +10,8 @@ type AttackTechnique struct { Description string MitreAttackTactics []mitreattack.Tactic Platform Platform - Detonate func(terraformOutputs map[string]string) error - Cleanup func() error + Detonate func(params map[string]string) error + Revert func(params map[string]string) error PrerequisitesTerraformCode []byte } diff --git a/pkg/stratus/runner/runner.go b/pkg/stratus/runner/runner.go index f8b246cb..a5d3af97 100644 --- a/pkg/stratus/runner/runner.go +++ b/pkg/stratus/runner/runner.go @@ -104,8 +104,32 @@ func (m *Runner) Detonate() error { return nil } +func (m *Runner) Revert() error { + if m.GetState() != stratus.AttackTechniqueStatusDetonated && !m.ShouldForce { + return errors.New(m.Technique.ID + " is not in DETONATED state and should not need to be reverted, use --force to force") + } + + outputs, err := m.StateManager.GetTerraformOutputs() + if err != nil { + return errors.New("unable to retrieve outputs of " + m.Technique.ID + ": " + err.Error()) + } + + log.Println("Reverting detonation of technique " + m.Technique.ID) + + if m.Technique.Revert != nil { + err = m.Technique.Revert(outputs) + if err != nil { + return errors.New("unable to revert detonation of " + m.Technique.ID + ": " + err.Error()) + } + } + + m.setState(stratus.AttackTechniqueStatusWarm) + + return nil +} + func (m *Runner) CleanUp() error { - var techniqueCleanupErr error + var techniqueRevertErr error var prerequisitesCleanupErr error // Has the technique already been cleaned up? @@ -116,10 +140,10 @@ func (m *Runner) CleanUp() error { log.Println("Cleaning up " + m.Technique.ID) // Revert detonation - if m.Technique.Cleanup != nil { - techniqueCleanupErr = m.Technique.Cleanup() - if techniqueCleanupErr != nil { - log.Println("Warning: unable to clean up TTP: " + techniqueCleanupErr.Error()) + if m.Technique.Revert != nil && m.GetState() == stratus.AttackTechniqueStatusDetonated { + techniqueRevertErr = m.Revert() + if techniqueRevertErr != nil { + log.Println("Warning: unable to revert detonation of " + m.Technique.ID + ": " + techniqueRevertErr.Error()) } } @@ -140,7 +164,7 @@ func (m *Runner) CleanUp() error { log.Println("Warning: unable to remove technique directory " + m.TerraformDir + ": " + err.Error()) } - return utils.CoalesceErr(techniqueCleanupErr, prerequisitesCleanupErr, err) + return utils.CoalesceErr(techniqueRevertErr, prerequisitesCleanupErr, err) } func (m *Runner) ValidatePlatformRequirements() { diff --git a/pkg/stratus/runner/runner_test.go b/pkg/stratus/runner/runner_test.go index f29149f0..053e8761 100644 --- a/pkg/stratus/runner/runner_test.go +++ b/pkg/stratus/runner/runner_test.go @@ -147,6 +147,90 @@ func TestRunnerDetonate(t *testing.T) { state.AssertCalled(t, "SetTechniqueState", stratus.AttackTechniqueState(stratus.AttackTechniqueStatusDetonated)) } +func TestRunnerRevert(t *testing.T) { + type TestRevertScenario struct { + Name string + TechniqueState stratus.AttackTechniqueState + Force bool + ExpectDidCallRevertFunction bool + ExpectDidChangeStateToWarm bool + ExpectError bool + } + scenario := []TestRevertScenario{ + { + Name: "DetonatedTechniqueIsReverted", + TechniqueState: stratus.AttackTechniqueStatusDetonated, + Force: false, + ExpectDidCallRevertFunction: true, + ExpectDidChangeStateToWarm: true, + ExpectError: false, + }, + { + Name: "WarmTechniqueIsNotReverted", + TechniqueState: stratus.AttackTechniqueStatusWarm, + Force: false, + ExpectDidCallRevertFunction: false, + ExpectDidChangeStateToWarm: false, + ExpectError: true, + }, + { + Name: "WarmTechniqueIsRevertedWithForce", + TechniqueState: stratus.AttackTechniqueStatusWarm, + Force: true, + ExpectDidCallRevertFunction: true, + ExpectDidChangeStateToWarm: true, + ExpectError: false, + }, + } + + for i := range scenario { + t.Run(scenario[i].Name, func(t *testing.T) { + state := new(statemocks.StateManager) + state.On("GetRootDirectory").Return("/root") + state.On("ExtractTechnique").Return(nil) + state.On("GetTerraformOutputs").Return(map[string]string{"foo": "bar"}, nil) + state.On("GetTechniqueState", mock.Anything).Return(scenario[i].TechniqueState) + state.On("SetTechniqueState", mock.Anything).Return(nil) + + var wasReverted = false + runner := Runner{ + Technique: &stratus.AttackTechnique{ + ID: "foo", + Detonate: func(map[string]string) error { return nil }, + Revert: func(params map[string]string) error { + wasReverted = true + return nil + }, + }, + ShouldForce: scenario[i].Force, + StateManager: state, + } + runner.initialize() + + err := runner.Revert() + + if scenario[i].ExpectError { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + + if scenario[i].ExpectDidCallRevertFunction { + assert.True(t, wasReverted) + } else { + assert.False(t, wasReverted) + } + + if scenario[i].ExpectDidChangeStateToWarm { + state.AssertCalled(t, "SetTechniqueState", stratus.AttackTechniqueState(stratus.AttackTechniqueStatusWarm)) + } else { + state.AssertNotCalled(t, "SetTechniqueState", stratus.AttackTechniqueState(stratus.AttackTechniqueStatusWarm)) + } + }) + } + +} + func TestRunnerCleanup(t *testing.T) { type RunnerCleanupTestScenario struct { Name string