Skip to content

Commit

Permalink
Add command to revert TTP detonation (closes #12)
Browse files Browse the repository at this point in the history
  • Loading branch information
christophetd committed Jan 19, 2022
1 parent 1bed795 commit 7e84bff
Show file tree
Hide file tree
Showing 11 changed files with 234 additions and 19 deletions.
2 changes: 2 additions & 0 deletions cmd/stratus/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ func init() {
showCmd := buildShowCmd()
warmupCmd := buildWarmupCmd()
detonateCmd := buildDetonateCmd()
revertCmd := buildRevertCmd()
statusCmd := buildStatusCmd()
cleanupCmd := buildCleanupCmd()
versionCmd := buildVersionCmd()
Expand All @@ -22,6 +23,7 @@ func init() {
rootCmd.AddCommand(showCmd)
rootCmd.AddCommand(warmupCmd)
rootCmd.AddCommand(detonateCmd)
rootCmd.AddCommand(revertCmd)
rootCmd.AddCommand(statusCmd)
rootCmd.AddCommand(cleanupCmd)
rootCmd.AddCommand(versionCmd)
Expand Down
43 changes: 43 additions & 0 deletions cmd/stratus/revert_cmd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package main

import (
"errors"
"github.com/datadog/stratus-red-team/pkg/stratus"
"github.com/datadog/stratus-red-team/pkg/stratus/runner"
"github.com/spf13/cobra"
"log"
)

var revertForce bool

func buildRevertCmd() *cobra.Command {
detonateCmd := &cobra.Command{
Use: "revert",
Short: "Revert the detonation of an attack technique",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) == 0 {
return errors.New("you must specify at least one attack technique")
}
_, err := resolveTechniques(args)
return err
},
Run: func(cmd *cobra.Command, args []string) {
techniques, _ := resolveTechniques(args)
doRevertCmd(techniques, revertForce)
},
}
detonateCmd.Flags().BoolVarP(&revertForce, "force", "f", false, "Force attempt to reverting even if the technique is not in the DETONATED state")
return detonateCmd
}

func doRevertCmd(techniques []*stratus.AttackTechnique, force bool) {
for i := range techniques {
stratusRunner := runner.NewRunner(techniques[i], force)
err := stratusRunner.Revert()
if err != nil {
log.Fatal(err)
}
}

doStatusCmd(techniques)
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Detonation: Calls cloudtrail:StopLogging
`,
PrerequisitesTerraformCode: tf,
Detonate: detonate,
Revert: revert,
})
}

Expand All @@ -49,3 +50,15 @@ func detonate(params map[string]string) error {

return nil
}

func revert(params map[string]string) error {
cloudtrailClient := cloudtrail.NewFromConfig(providers.AWS().GetConnection())
trailName := params["cloudtrail_trail_name"]

log.Println("Restarting CloudTrail trail " + trailName)
_, err := cloudtrailClient.StartLogging(context.Background(), &cloudtrail.StartLoggingInput{
Name: aws.String(trailName),
})

return err
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,6 @@ func detonate(params map[string]string) error {

return nil
}

// The technique is non-revertible once it has been detonated, otherwise it would require re-creating the VPC
// flow log programmatically, which we don't want as it's implemented in the Terraform for the warm-up phase
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ Detonation: Calls ModifySnapshotAttribute to share the snapshot.
`,
PrerequisitesTerraformCode: tf,
Detonate: detonate,
Revert: revert,
})
}

const ShareWithAccountId = "012345678912"

func detonate(params map[string]string) error {
ec2Client := ec2.NewFromConfig(providers.AWS().GetConnection())

Expand All @@ -46,7 +49,22 @@ func detonate(params map[string]string) error {
SnapshotId: aws.String(ourSnapshotId),
Attribute: types.SnapshotAttributeNameCreateVolumePermission,
CreateVolumePermission: &types.CreateVolumePermissionModifications{
Add: []types.CreateVolumePermission{{UserId: aws.String("012345678912")}},
Add: []types.CreateVolumePermission{{UserId: aws.String(ShareWithAccountId)}},
},
})
return err
}

func revert(params map[string]string) error {
ec2Client := ec2.NewFromConfig(providers.AWS().GetConnection())
ourSnapshotId := params["snapshot_id"]

log.Println("Unsharing the volume snapshot " + ourSnapshotId)
_, err := ec2Client.ModifySnapshotAttribute(context.Background(), &ec2.ModifySnapshotAttributeInput{
SnapshotId: aws.String(ourSnapshotId),
Attribute: types.SnapshotAttributeNameCreateVolumePermission,
CreateVolumePermission: &types.CreateVolumePermissionModifications{
Remove: []types.CreateVolumePermission{{UserId: aws.String(ShareWithAccountId)}},
},
})
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Detonation: Backdoors the S3 bucket policy.
`,
PrerequisitesTerraformCode: tf,
Detonate: detonate,
Revert: revert,
})
}

Expand All @@ -49,3 +50,15 @@ func detonate(params map[string]string) error {

return err
}

func revert(params map[string]string) error {
s3Client := s3.NewFromConfig(providers.AWS().GetConnection())
bucketName := params["bucket_name"]

log.Println("Removing malicious bucket policy on " + bucketName)
_, err := s3Client.DeleteBucketPolicy(context.Background(), &s3.DeleteBucketPolicyInput{
Bucket: aws.String(bucketName),
})

return err
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,27 @@ Detonation: Create the access key.
log.Println("Successfully created access key " + *result.AccessKey.AccessKeyId)
return nil
},
Cleanup: func() error {
// TODO: https://github.com/DataDog/stratus-red-team/issues/12
/*iamClient := iam.NewFromConfig(providers.AWS().GetConnection())
log.Println("Removing access key from IAM user")
result, err := iamClient.ListAccessKeys(context.Background(), &iam.ListAccessKeysInput{UserName: aws.String("sample-legit-user")})
Revert: func(params map[string]string) error {
iamClient := iam.NewFromConfig(providers.AWS().GetConnection())
userName := params["user_name"]
log.Println("Removing access key from IAM user " + userName)
result, err := iamClient.ListAccessKeys(context.Background(), &iam.ListAccessKeysInput{
UserName: aws.String(userName),
})
if err != nil {
return err
}
for i := range result.AccessKeyMetadata {
iamClient.DeleteAccessKey(context.Background(), &iam.DeleteAccessKeyInput{AccessKeyId: result.AccessKeyMetadata[i].AccessKeyId})
}*/
accessKeyId := result.AccessKeyMetadata[i].AccessKeyId
log.Println("Removing access key " + *accessKeyId)
_, err := iamClient.DeleteAccessKey(context.Background(), &iam.DeleteAccessKeyInput{
AccessKeyId: accessKeyId,
UserName: aws.String(userName),
})
if err != nil {
log.Println("failed: " + err.Error())
}
}

return nil
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,20 @@ Detonation: Creates the IAM user and attached 'AdministratorAccess' to it.

return nil
},
Cleanup: func() error {
Revert: func(params map[string]string) error {
iamClient := iam.NewFromConfig(providers.AWS().GetConnection())
result, err := iamClient.ListAccessKeys(context.Background(), &iam.ListAccessKeysInput{UserName: userName})
result, err := iamClient.ListAccessKeys(context.Background(), &iam.ListAccessKeysInput{
UserName: userName,
})
if err != nil {
return errors.New("unable to clean up IAM user access keys: " + err.Error())
}
for i := range result.AccessKeyMetadata {
accessKeyId := result.AccessKeyMetadata[i].AccessKeyId
_, err := iamClient.DeleteAccessKey(context.Background(), &iam.DeleteAccessKeyInput{UserName: userName, AccessKeyId: accessKeyId})
_, err := iamClient.DeleteAccessKey(context.Background(), &iam.DeleteAccessKeyInput{
UserName: userName,
AccessKeyId: accessKeyId,
})
if err != nil {
return errors.New("unable to remove IAM user access key " + *accessKeyId + ": " + err.Error())
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/stratus/attack_technique.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ type AttackTechnique struct {
Description string
MitreAttackTactics []mitreattack.Tactic
Platform Platform
Detonate func(terraformOutputs map[string]string) error
Cleanup func() error
Detonate func(params map[string]string) error
Revert func(params map[string]string) error
PrerequisitesTerraformCode []byte
}

Expand Down
36 changes: 30 additions & 6 deletions pkg/stratus/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,32 @@ func (m *Runner) Detonate() error {
return nil
}

func (m *Runner) Revert() error {
if m.GetState() != stratus.AttackTechniqueStatusDetonated && !m.ShouldForce {
return errors.New(m.Technique.ID + " is not in DETONATED state and should not need to be reverted, use --force to force")
}

outputs, err := m.StateManager.GetTerraformOutputs()
if err != nil {
return errors.New("unable to retrieve outputs of " + m.Technique.ID + ": " + err.Error())
}

log.Println("Reverting detonation of technique " + m.Technique.ID)

if m.Technique.Revert != nil {
err = m.Technique.Revert(outputs)
if err != nil {
return errors.New("unable to revert detonation of " + m.Technique.ID + ": " + err.Error())
}
}

m.setState(stratus.AttackTechniqueStatusWarm)

return nil
}

func (m *Runner) CleanUp() error {
var techniqueCleanupErr error
var techniqueRevertErr error
var prerequisitesCleanupErr error

// Has the technique already been cleaned up?
Expand All @@ -116,10 +140,10 @@ func (m *Runner) CleanUp() error {
log.Println("Cleaning up " + m.Technique.ID)

// Revert detonation
if m.Technique.Cleanup != nil {
techniqueCleanupErr = m.Technique.Cleanup()
if techniqueCleanupErr != nil {
log.Println("Warning: unable to clean up TTP: " + techniqueCleanupErr.Error())
if m.Technique.Revert != nil && m.GetState() == stratus.AttackTechniqueStatusDetonated {
techniqueRevertErr = m.Revert()
if techniqueRevertErr != nil {
log.Println("Warning: unable to revert detonation of " + m.Technique.ID + ": " + techniqueRevertErr.Error())
}
}

Expand All @@ -140,7 +164,7 @@ func (m *Runner) CleanUp() error {
log.Println("Warning: unable to remove technique directory " + m.TerraformDir + ": " + err.Error())
}

return utils.CoalesceErr(techniqueCleanupErr, prerequisitesCleanupErr, err)
return utils.CoalesceErr(techniqueRevertErr, prerequisitesCleanupErr, err)
}

func (m *Runner) ValidatePlatformRequirements() {
Expand Down
84 changes: 84 additions & 0 deletions pkg/stratus/runner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,90 @@ func TestRunnerDetonate(t *testing.T) {
state.AssertCalled(t, "SetTechniqueState", stratus.AttackTechniqueState(stratus.AttackTechniqueStatusDetonated))
}

func TestRunnerRevert(t *testing.T) {
type TestRevertScenario struct {
Name string
TechniqueState stratus.AttackTechniqueState
Force bool
ExpectDidCallRevertFunction bool
ExpectDidChangeStateToWarm bool
ExpectError bool
}
scenario := []TestRevertScenario{
{
Name: "DetonatedTechniqueIsReverted",
TechniqueState: stratus.AttackTechniqueStatusDetonated,
Force: false,
ExpectDidCallRevertFunction: true,
ExpectDidChangeStateToWarm: true,
ExpectError: false,
},
{
Name: "WarmTechniqueIsNotReverted",
TechniqueState: stratus.AttackTechniqueStatusWarm,
Force: false,
ExpectDidCallRevertFunction: false,
ExpectDidChangeStateToWarm: false,
ExpectError: true,
},
{
Name: "WarmTechniqueIsRevertedWithForce",
TechniqueState: stratus.AttackTechniqueStatusWarm,
Force: true,
ExpectDidCallRevertFunction: true,
ExpectDidChangeStateToWarm: true,
ExpectError: false,
},
}

for i := range scenario {
t.Run(scenario[i].Name, func(t *testing.T) {
state := new(statemocks.StateManager)
state.On("GetRootDirectory").Return("/root")
state.On("ExtractTechnique").Return(nil)
state.On("GetTerraformOutputs").Return(map[string]string{"foo": "bar"}, nil)
state.On("GetTechniqueState", mock.Anything).Return(scenario[i].TechniqueState)
state.On("SetTechniqueState", mock.Anything).Return(nil)

var wasReverted = false
runner := Runner{
Technique: &stratus.AttackTechnique{
ID: "foo",
Detonate: func(map[string]string) error { return nil },
Revert: func(params map[string]string) error {
wasReverted = true
return nil
},
},
ShouldForce: scenario[i].Force,
StateManager: state,
}
runner.initialize()

err := runner.Revert()

if scenario[i].ExpectError {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}

if scenario[i].ExpectDidCallRevertFunction {
assert.True(t, wasReverted)
} else {
assert.False(t, wasReverted)
}

if scenario[i].ExpectDidChangeStateToWarm {
state.AssertCalled(t, "SetTechniqueState", stratus.AttackTechniqueState(stratus.AttackTechniqueStatusWarm))
} else {
state.AssertNotCalled(t, "SetTechniqueState", stratus.AttackTechniqueState(stratus.AttackTechniqueStatusWarm))
}
})
}

}

func TestRunnerCleanup(t *testing.T) {
type RunnerCleanupTestScenario struct {
Name string
Expand Down

0 comments on commit 7e84bff

Please sign in to comment.