Skip to content

Commit

Permalink
aws: Add S3DescribeBucket, and S3GetAccessPoint (#3023)
Browse files Browse the repository at this point in the history
  • Loading branch information
tolujimoh committed May 23, 2024
1 parent 6209809 commit ee9d920
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 26 deletions.
49 changes: 33 additions & 16 deletions backend/mock/service/awsmock/awsmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,15 @@ import (

dynamodbv1 "github.com/lyft/clutch/backend/api/aws/dynamodb/v1"
ec2v1 "github.com/lyft/clutch/backend/api/aws/ec2/v1"
iamv1 "github.com/lyft/clutch/backend/api/aws/iam/v1"
kinesisv1 "github.com/lyft/clutch/backend/api/aws/kinesis/v1"
s3v1 "github.com/lyft/clutch/backend/api/aws/s3/v1"
"github.com/lyft/clutch/backend/service"
clutchawsclient "github.com/lyft/clutch/backend/service/aws"
)

type svc struct{}

func (s *svc) S3GetAccessPointPolicy(ctx context.Context, account, region, accessPointName string, accountID string) (*s3control.GetAccessPointPolicyOutput, error) {
return &s3control.GetAccessPointPolicyOutput{
Policy: aws.String("{}"),
ResultMetadata: middleware.Metadata{},
}, nil
}

func (s *svc) GetIAMRole(ctx context.Context, account, region, roleName string) (*iam.GetRoleOutput, error) {
return &iam.GetRoleOutput{
Role: &iamtypes.Role{
RoleName: aws.String(roleName),
Arn: aws.String(fmt.Sprintf("arn:aws:iam::%s:role/%s", account, roleName)),
},
}, nil
}

func (s *svc) GetDirectClient(account string, region string) (clutchawsclient.DirectClient, error) {
panic("implement me")
}
Expand Down Expand Up @@ -132,6 +118,12 @@ func (s *svc) RebootInstances(ctx context.Context, account, region string, ids [
return nil
}

func (s *svc) S3DescribeBucket(ctx context.Context, account, region, bucket string) (*s3v1.Bucket, error) {
return &s3v1.Bucket{
Name: bucket,
}, nil
}

func (s *svc) S3StreamingGet(ctx context.Context, account, region, bucket, key string) (io.ReadCloser, error) {
panic("implement me")
}
Expand All @@ -143,6 +135,13 @@ func (s *svc) S3GetBucketPolicy(ctx context.Context, account, region, bucket, ac
}, nil
}

func (s *svc) S3GetAccessPointPolicy(ctx context.Context, account, region, accessPointName string, accountID string) (*s3control.GetAccessPointPolicyOutput, error) {
return &s3control.GetAccessPointPolicyOutput{
Policy: aws.String("{}"),
ResultMetadata: middleware.Metadata{},
}, nil
}

func (s *svc) DescribeTable(ctx context.Context, account, region, tableName string) (*dynamodbv1.Table, error) {
ret := &dynamodbv1.Table{
Name: tableName,
Expand Down Expand Up @@ -324,6 +323,24 @@ func (s *svc) SimulateCustomPolicy(ctx context.Context, account, region string,
}, nil
}

func (s *svc) GetIAMRole(ctx context.Context, account, region, roleName string) (*iamv1.Role, error) {
return &iamv1.Role{
Name: roleName,
Arn: fmt.Sprintf("arn:aws:iam::%s:role/%s", account, roleName),
}, nil
}

func (s *svc) S3GetAccessPoint(ctx context.Context, account, region, accessPointName, accountId string) (*s3v1.AccessPoint, error) {
return &s3v1.AccessPoint{
Name: accessPointName,
AccessPointArn: fmt.Sprintf("arn:aws:s3:%s:%s:accesspoint/%s", region, account, accessPointName),
Bucket: "my-bucket",
Alias: "alias",
BucketAccountId: accountId,
CreationDate: timestamppb.New(time.Now()),
}, nil
}

func New() clutchawsclient.Client {
return &svc{}
}
Expand Down
8 changes: 6 additions & 2 deletions backend/service/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ import (

dynamodbv1 "github.com/lyft/clutch/backend/api/aws/dynamodb/v1"
ec2v1 "github.com/lyft/clutch/backend/api/aws/ec2/v1"
iamv1 "github.com/lyft/clutch/backend/api/aws/iam/v1"
kinesisv1 "github.com/lyft/clutch/backend/api/aws/kinesis/v1"
s3v1 "github.com/lyft/clutch/backend/api/aws/s3/v1"
awsv1 "github.com/lyft/clutch/backend/api/config/service/aws/v1"
topologyv1 "github.com/lyft/clutch/backend/api/topology/v1"
"github.com/lyft/clutch/backend/service"
Expand Down Expand Up @@ -177,10 +179,12 @@ type Client interface {
DescribeKinesisStream(ctx context.Context, account, region, streamName string) (*kinesisv1.Stream, error)
UpdateKinesisShardCount(ctx context.Context, account, region, streamName string, targetShardCount int32) error

S3DescribeBucket(ctx context.Context, account, region, bucket string) (*s3v1.Bucket, error)
S3GetAccessPoint(ctx context.Context, account, region, accessPointName, accountId string) (*s3v1.AccessPoint, error)
S3GetAccessPointPolicy(ctx context.Context, account, region, accessPointName, accountId string) (*s3control.GetAccessPointPolicyOutput, error)
S3GetBucketPolicy(ctx context.Context, account, region, bucket, accountID string) (*s3.GetBucketPolicyOutput, error)
S3StreamingGet(ctx context.Context, account, region, bucket, key string) (io.ReadCloser, error)

S3GetAccessPointPolicy(ctx context.Context, account, region, accessPointName, accountID string) (*s3control.GetAccessPointPolicyOutput, error)
DescribeTable(ctx context.Context, account, region, tableName string) (*dynamodbv1.Table, error)
UpdateCapacity(ctx context.Context, account, region, tableName string, targetTableCapacity *dynamodbv1.Throughput, indexUpdates []*dynamodbv1.IndexUpdateAction, ignoreMaximums bool) (*dynamodbv1.Table, error)
BatchGetItem(ctx context.Context, account, region string, params *dynamodb.BatchGetItemInput) (*dynamodb.BatchGetItemOutput, error)
Expand All @@ -189,7 +193,7 @@ type Client interface {
GetCallerIdentity(ctx context.Context, account, region string) (*sts.GetCallerIdentityOutput, error)

SimulateCustomPolicy(ctx context.Context, account, region string, customPolicySimulatorParams *iam.SimulateCustomPolicyInput) (*iam.SimulateCustomPolicyOutput, error)
GetIAMRole(ctx context.Context, account, region, roleName string) (*iam.GetRoleOutput, error)
GetIAMRole(ctx context.Context, account, region, roleName string) (*iamv1.Role, error)

Accounts() []string
AccountsAndRegions() map[string][]string
Expand Down
19 changes: 17 additions & 2 deletions backend/service/aws/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"

"github.com/aws/aws-sdk-go-v2/service/iam"
"google.golang.org/protobuf/types/known/timestamppb"

iamv1 "github.com/lyft/clutch/backend/api/aws/iam/v1"
)

func (c *client) SimulateCustomPolicy(ctx context.Context, account, region string, customPolicySimulatorParams *iam.SimulateCustomPolicyInput) (*iam.SimulateCustomPolicyOutput, error) {
Expand All @@ -20,13 +23,25 @@ func (c *client) GetIAMRole(
account,
region,
roleName string,
) (*iam.GetRoleOutput, error) {
) (*iamv1.Role, error) {
cl, err := c.getAccountRegionClient(account, region)
if err != nil {
return nil, err
}

return cl.iam.GetRole(ctx, &iam.GetRoleInput{
role, err := cl.iam.GetRole(ctx, &iam.GetRoleInput{
RoleName: &roleName,
})
if err != nil {
return nil, err
}

return &iamv1.Role{
Arn: *role.Role.Arn,
Name: *role.Role.RoleName,
CreatedDate: timestamppb.New(*role.Role.CreateDate),
Id: *role.Role.RoleId,
Account: account,
Region: region,
}, nil
}
19 changes: 14 additions & 5 deletions backend/service/aws/iam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ import (
type mockIAM struct {
getSimulationResultsErr error
getSimulationResults *iam.SimulateCustomPolicyOutput
getGetIAMRoleResultsErr error
getGetIAMRoleResults *iam.GetRoleOutput
getIAMRoleResultsErr error
getIAMRoleResults *iam.GetRoleOutput
listIAMRolesResultsErr error
listIAMRolesResults *iam.ListRolesOutput
}

func TestIAMSimulateCustomPolicy(t *testing.T) {
Expand Down Expand Up @@ -108,8 +110,15 @@ func (m *mockIAM) SimulateCustomPolicy(ctx context.Context, params *iam.Simulate
}

func (m *mockIAM) GetRole(ctx context.Context, params *iam.GetRoleInput, optFns ...func(*iam.Options)) (*iam.GetRoleOutput, error) {
if m.getGetIAMRoleResultsErr != nil {
return nil, m.getGetIAMRoleResultsErr
if m.getIAMRoleResultsErr != nil {
return nil, m.getIAMRoleResultsErr
}
return m.getGetIAMRoleResults, nil
return m.getIAMRoleResults, nil
}

func (m *mockIAM) ListRoles(ctx context.Context, params *iam.ListRolesInput, optFns ...func(*iam.Options)) (*iam.ListRolesOutput, error) {
if m.listIAMRolesResultsErr != nil {
return nil, m.listIAMRolesResultsErr
}
return m.listIAMRolesResults, nil
}
7 changes: 6 additions & 1 deletion backend/service/aws/iface.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ import (
)

type s3Client interface {
HeadBucket(ctx context.Context, params *s3.HeadBucketInput, optFns ...func(*s3.Options)) (*s3.HeadBucketOutput, error)
GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error)
GetBucketPolicy(ctx context.Context, params *s3.GetBucketPolicyInput, optFns ...func(*s3.Options)) (*s3.GetBucketPolicyOutput, error)
ListBuckets(ctx context.Context, params *s3.ListBucketsInput, optFns ...func(*s3.Options)) (*s3.ListBucketsOutput, error)
}

type s3ControlClient interface {
ListAccessPoints(ctx context.Context, params *s3control.ListAccessPointsInput, optFns ...func(*s3control.Options)) (*s3control.ListAccessPointsOutput, error)
GetAccessPoint(ctx context.Context, params *s3control.GetAccessPointInput, optFns ...func(*s3control.Options)) (*s3control.GetAccessPointOutput, error)
GetAccessPointPolicy(ctx context.Context, params *s3control.GetAccessPointPolicyInput, optFns ...func(*s3control.Options)) (*s3control.GetAccessPointPolicyOutput, error)
}
type stsClient interface {
Expand All @@ -30,8 +34,9 @@ type stsClient interface {
}

type iamClient interface {
SimulateCustomPolicy(ctx context.Context, params *iam.SimulateCustomPolicyInput, optFns ...func(*iam.Options)) (*iam.SimulateCustomPolicyOutput, error)
GetRole(ctx context.Context, params *iam.GetRoleInput, optFns ...func(options *iam.Options)) (*iam.GetRoleOutput, error)
ListRoles(ctx context.Context, params *iam.ListRolesInput, optFns ...func(options *iam.Options)) (*iam.ListRolesOutput, error)
SimulateCustomPolicy(ctx context.Context, params *iam.SimulateCustomPolicyInput, optFns ...func(*iam.Options)) (*iam.SimulateCustomPolicyOutput, error)
}

type kinesisClient interface {
Expand Down
24 changes: 24 additions & 0 deletions backend/service/aws/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"

s3v1 "github.com/lyft/clutch/backend/api/aws/s3/v1"
)

func (c *client) S3GetBucketPolicy(ctx context.Context, account, region, bucket, accountID string) (*s3.GetBucketPolicyOutput, error) {
Expand Down Expand Up @@ -40,3 +42,25 @@ func (c *client) S3StreamingGet(ctx context.Context, account, region, bucket, ke

return out.Body, nil
}

func (c *client) S3DescribeBucket(ctx context.Context, account, region, bucket string) (*s3v1.Bucket, error) {
cl, err := c.getAccountRegionClient(account, region)
if err != nil {
return nil, err
}

in := &s3.HeadBucketInput{
Bucket: aws.String(bucket),
}

bucketHeaders, err := cl.s3.HeadBucket(ctx, in)
if err != nil {
return nil, err
}

return &s3v1.Bucket{
Name: bucket,
Region: *bucketHeaders.BucketRegion,
Account: account,
}, nil
}
77 changes: 77 additions & 0 deletions backend/service/aws/s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/smithy-go/middleware"
"github.com/stretchr/testify/assert"

s3v1 "github.com/lyft/clutch/backend/api/aws/s3/v1"
)

func TestS3StreamGet(t *testing.T) {
Expand Down Expand Up @@ -111,12 +113,79 @@ func TestS3GetBucketPolicyErrorHandling(t *testing.T) {
assert.Error(t, err2)
}

func TestS3DescribeBucket(t *testing.T) {
s3Client := &mockS3{
getHeadBucketOutput: &s3.HeadBucketOutput{
BucketRegion: aws.String("us-east-1"),
},
}

c := &client{
currentAccountAlias: "default",
accounts: map[string]*accountClients{
"default": {
clients: map[string]*regionalClient{
"us-east-1": {region: "us-east-1", s3: s3Client},
},
},
},
}

output, err := c.S3DescribeBucket(context.Background(), "default", "us-east-1", "clutch")
assert.NoError(t, err)
assert.Equal(t, output, &s3v1.Bucket{
Name: "clutch",
Region: "us-east-1",
Account: "default",
})
}

func TestS3DescribeBucketErrorHandling(t *testing.T) {
s3Client := &mockS3{
getHeadBucketErr: fmt.Errorf("error"),
}

c := &client{
currentAccountAlias: "default",
accounts: map[string]*accountClients{
"default": {
clients: map[string]*regionalClient{
"us-east-1": {region: "us-east-1", s3: s3Client},
},
},
},
}

output1, err1 := c.S3DescribeBucket(context.Background(), "default", "us-east-1", "clutch")
assert.Nil(t, output1)
assert.Error(t, err1)

// Test unknown region
output2, err2 := c.S3DescribeBucket(context.Background(), "default", "choice-region-1", "clutch")
assert.Nil(t, output2)
assert.Error(t, err2)
}

type mockS3 struct {
getObjectErr error
getObjectOutput *s3.GetObjectOutput

getObjectPolicyErr error
getObjectPolicyOutput *s3.GetBucketPolicyOutput

getHeadBucketErr error
getHeadBucketOutput *s3.HeadBucketOutput

listRolesErr error
listRolesOutput *s3.ListBucketsOutput
}

func (m *mockS3) HeadBucket(ctx context.Context, params *s3.HeadBucketInput, optFns ...func(*s3.Options)) (*s3.HeadBucketOutput, error) {
if m.getHeadBucketErr != nil {
return nil, m.getHeadBucketErr
}

return m.getHeadBucketOutput, nil
}

func (m *mockS3) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
Expand All @@ -134,3 +203,11 @@ func (m *mockS3) GetBucketPolicy(ctx context.Context, params *s3.GetBucketPolicy

return m.getObjectPolicyOutput, nil
}

func (m *mockS3) ListBuckets(ctx context.Context, params *s3.ListBucketsInput, optFns ...func(*s3.Options)) (*s3.ListBucketsOutput, error) {
if m.listRolesErr != nil {
return nil, m.listRolesErr
}

return m.listRolesOutput, nil
}
31 changes: 31 additions & 0 deletions backend/service/aws/s3control.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3control"
"google.golang.org/protobuf/types/known/timestamppb"

s3v1 "github.com/lyft/clutch/backend/api/aws/s3/v1"
)

func (c *client) S3GetAccessPointPolicy(ctx context.Context, account, region, accessPointName, accountID string) (*s3control.GetAccessPointPolicyOutput, error) {
Expand All @@ -20,3 +23,31 @@ func (c *client) S3GetAccessPointPolicy(ctx context.Context, account, region, ac

return cl.s3control.GetAccessPointPolicy(ctx, in)
}

func (c *client) S3GetAccessPoint(ctx context.Context, account, region, accessPointName, accountId string) (*s3v1.AccessPoint, error) {
cl, err := c.getAccountRegionClient(account, region)
if err != nil {
return nil, err
}

in := &s3control.GetAccessPointInput{
Name: aws.String(accessPointName),
AccountId: aws.String(accountId),
}

out, err := cl.s3control.GetAccessPoint(ctx, in)
if err != nil {
return nil, err
}

return &s3v1.AccessPoint{
Name: *out.Name,
AccessPointArn: *out.AccessPointArn,
Bucket: *out.Bucket,
Alias: *out.Alias,
BucketAccountId: *out.BucketAccountId,
CreationDate: timestamppb.New(*out.CreationDate),
Account: account,
Region: region,
}, nil
}
Loading

0 comments on commit ee9d920

Please sign in to comment.