From 20d5cca42b26024124f90fdb134d3e501e51a053 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Sun, 2 Jun 2024 11:51:01 +0200 Subject: [PATCH] Allow athenaevents to bypass SNS --- examples/dynamoathenamigration/migration.go | 5 +- lib/events/athena/athena.go | 13 +- lib/events/athena/athena_test.go | 6 +- lib/events/athena/fakequeue_test.go | 56 +++--- lib/events/athena/publisher.go | 198 ++++++++++++++------ lib/events/athena/publisher_test.go | 16 +- lib/events/athena/test.go | 13 +- 7 files changed, 201 insertions(+), 106 deletions(-) diff --git a/examples/dynamoathenamigration/migration.go b/examples/dynamoathenamigration/migration.go index 28ec1af01bfea..5c3dff3943988 100644 --- a/examples/dynamoathenamigration/migration.go +++ b/examples/dynamoathenamigration/migration.go @@ -165,15 +165,14 @@ func newMigrateTask(ctx context.Context, cfg Config, awsCfg aws.Config) (*task, dynamoClient: dynamodb.NewFromConfig(awsCfg), s3Downloader: manager.NewDownloader(s3Client), eventsEmitter: athena.NewPublisher(athena.PublisherConfig{ - TopicARN: cfg.TopicARN, - SNSPublisher: sns.NewFromConfig(awsCfg, func(o *sns.Options) { + MessagePublisher: athena.SNSPublisherFunc(cfg.TopicARN, sns.NewFromConfig(awsCfg, func(o *sns.Options) { o.Retryer = retry.NewStandard(func(so *retry.StandardOptions) { so.MaxAttempts = 30 so.MaxBackoff = 1 * time.Minute // Use bigger rate limit to handle default sdk throttling: https://github.com/aws/aws-sdk-go-v2/issues/1665 so.RateLimiter = ratelimit.NewTokenRateLimit(1000000) }) - }), + })), Uploader: manager.NewUploader(s3Client), PayloadBucket: cfg.LargePayloadBucket, PayloadPrefix: cfg.LargePayloadPrefix, diff --git a/lib/events/athena/athena.go b/lib/events/athena/athena.go index 3d4da4c0ce105..1c5bbff7cc922 100644 --- a/lib/events/athena/athena.go +++ b/lib/events/athena/athena.go @@ -52,6 +52,11 @@ const ( defaultBatchItems = 20000 // defaultBatchInterval defines default batch interval. defaultBatchInterval = 1 * time.Minute + + // topicARNBypass is a magic value for TopicARN that signifies that the + // Athena audit log should send messages directly to SQS instead of going + // through a SNS topic. + topicARNBypass = "bypass" ) // Config structure represents Athena configuration. @@ -62,7 +67,9 @@ type Config struct { // Publisher settings. - // TopicARN where to emit events in SNS (required). + // TopicARN where to emit events in SNS (required). If TopicARN is "bypass" + // (i.e. [topicArnBypass]) then the events should be emitted directly to the + // SQS queue reachable at QueryURL. TopicARN string // LargeEventsS3 is location on S3 where temporary large events (>256KB) // are stored before converting it to Parquet and moving to long term @@ -106,7 +113,9 @@ type Config struct { // Batcher settings. - // QueueURL is URL of SQS, which is set as subscriber to SNS topic (required). + // QueueURL is URL of SQS, which is set as subscriber to SNS topic if we're + // emitting to SNS, or used directly to send messages if we're bypassing SNS + // (required). QueueURL string // BatchMaxItems defines how many items can be stored in single Parquet // batch (optional). diff --git a/lib/events/athena/athena_test.go b/lib/events/athena/athena_test.go index b3871d0d1b769..e0e52d8903b73 100644 --- a/lib/events/athena/athena_test.go +++ b/lib/events/athena/athena_test.go @@ -335,7 +335,7 @@ func TestPublisherConsumer(t *testing.T) { ID: uuid.NewString(), Time: time.Now().UTC(), Type: events.AppCreateEvent, - Code: strings.Repeat("d", 2*maxSNSMessageSize), + Code: strings.Repeat("d", 2*maxDirectMessageSize), }, AppMetadata: apievents.AppMetadata{ AppName: "app-large", @@ -418,8 +418,8 @@ func TestPublisherConsumer(t *testing.T) { fq := newFakeQueue() p := &publisher{ PublisherConfig: PublisherConfig{ - SNSPublisher: fq, - Uploader: fS3, + MessagePublisher: fq, + Uploader: fS3, }, } cfg := validCollectCfgForTests(t) diff --git a/lib/events/athena/fakequeue_test.go b/lib/events/athena/fakequeue_test.go index fbd3376ce91de..f3304fd280281 100644 --- a/lib/events/athena/fakequeue_test.go +++ b/lib/events/athena/fakequeue_test.go @@ -23,10 +23,8 @@ import ( "sync" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/sns" - snsTypes "github.com/aws/aws-sdk-go-v2/service/sns/types" "github.com/aws/aws-sdk-go-v2/service/sqs" - sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types" + sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types" "github.com/google/uuid" ) @@ -42,27 +40,27 @@ type fakeQueue struct { } type fakeQueueMessage struct { - payload string - attributes map[string]snsTypes.MessageAttributeValue + payload string + s3Based bool } func newFakeQueue() *fakeQueue { return &fakeQueue{} } -func (f *fakeQueue) Publish(ctx context.Context, params *sns.PublishInput, optFns ...func(*sns.Options)) (*sns.PublishOutput, error) { +func (f *fakeQueue) Publish(ctx context.Context, base64Body string, s3Based bool) error { f.mu.Lock() defer f.mu.Unlock() if len(f.publishErrors) > 0 { err := f.publishErrors[0] f.publishErrors = f.publishErrors[1:] - return nil, err + return err } f.msgs = append(f.msgs, fakeQueueMessage{ - payload: *params.Message, - attributes: params.MessageAttributes, + payload: base64Body, + s3Based: s3Based, }) - return nil, nil + return nil } func (f *fakeQueue) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { @@ -70,11 +68,27 @@ func (f *fakeQueue) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessa if len(msgs) == 0 { return &sqs.ReceiveMessageOutput{}, nil } - out := make([]sqsTypes.Message, 0, 10) + out := make([]sqstypes.Message, 0, len(msgs)) for _, msg := range msgs { - out = append(out, sqsTypes.Message{ - Body: aws.String(msg.payload), - MessageAttributes: snsToSqsAttributes(msg.attributes), + var messageAttributes map[string]sqstypes.MessageAttributeValue + if msg.s3Based { + messageAttributes = map[string]sqstypes.MessageAttributeValue{ + payloadTypeAttr: { + DataType: aws.String("String"), + StringValue: aws.String(payloadTypeS3Based), + }, + } + } else { + messageAttributes = map[string]sqstypes.MessageAttributeValue{ + payloadTypeAttr: { + DataType: aws.String("String"), + StringValue: aws.String(payloadTypeRawProtoEvent), + }, + } + } + out = append(out, sqstypes.Message{ + Body: &msg.payload, + MessageAttributes: messageAttributes, ReceiptHandle: aws.String(uuid.NewString()), }) } @@ -83,20 +97,6 @@ func (f *fakeQueue) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessa }, nil } -func snsToSqsAttributes(in map[string]snsTypes.MessageAttributeValue) map[string]sqsTypes.MessageAttributeValue { - if in == nil { - return nil - } - out := map[string]sqsTypes.MessageAttributeValue{} - for k, v := range in { - out[k] = sqsTypes.MessageAttributeValue{ - DataType: v.DataType, - StringValue: v.StringValue, - } - } - return out -} - func (f *fakeQueue) dequeue() []fakeQueueMessage { f.mu.Lock() defer f.mu.Unlock() diff --git a/lib/events/athena/publisher.go b/lib/events/athena/publisher.go index e53ab72b23120..54d75009003a8 100644 --- a/lib/events/athena/publisher.go +++ b/lib/events/athena/publisher.go @@ -22,19 +22,25 @@ import ( "bytes" "context" "encoding/base64" - "path/filepath" + "net/http" + "path" "time" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/retry" - "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + awsratelimit "github.com/aws/aws-sdk-go-v2/aws/ratelimit" + awsretry "github.com/aws/aws-sdk-go-v2/aws/retry" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/sns" - snsTypes "github.com/aws/aws-sdk-go-v2/service/sns/types" + snstypes "github.com/aws/aws-sdk-go-v2/service/sns/types" + "github.com/aws/aws-sdk-go-v2/service/sqs" + sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types" "github.com/google/uuid" "github.com/gravitational/trace" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" ) @@ -43,19 +49,18 @@ const ( payloadTypeRawProtoEvent = "raw_proto_event" payloadTypeS3Based = "s3_event" - // maxSNSMessageSize defines maximum size of SNS message. AWS allows 256KB - // however it counts also headers. We round it to 250KB, just to be sure. - maxSNSMessageSize = 250 * 1024 + // maxDirectMessageSize defines maximum size of SNS/SQS message. AWS allows + // 256KB however it counts also headers. We round it to 250KB, just to be + // sure. + maxDirectMessageSize = 250 * 1024 ) -var ( - // maxS3BasedSize defines some resonable threshold for S3 based messages - // (almost 2GiB but fits in an int). - // - // It's a var instead of const so tests can override it instead of casually - // allocating 2GiB. - maxS3BasedSize = 2*1024*1024*1024 - 1 -) +// maxS3BasedSize defines some resonable threshold for S3 based messages +// (almost 2GiB but fits in an int). +// +// It's a var instead of const so tests can override it instead of casually +// allocating 2GiB. +var maxS3BasedSize = 2*1024*1024*1024 - 1 // publisher is a SNS based events publisher. // It publishes proto events directly to SNS topic, or use S3 bucket @@ -64,20 +69,89 @@ type publisher struct { PublisherConfig } -type snsPublisher interface { - Publish(ctx context.Context, params *sns.PublishInput, optFns ...func(*sns.Options)) (*sns.PublishOutput, error) +type messagePublisher interface { + // Publish sends a message with a given body to a notification topic or a + // queue (or something similar), with added metadata to signify whether or + // not the message is only a reference to a S3 object or a full message. + Publish(ctx context.Context, base64Body string, s3Based bool) error +} + +type messagePublisherFunc func(ctx context.Context, base64Body string, s3Based bool) error + +// Publish implements [messagePublisher]. +func (f messagePublisherFunc) Publish(ctx context.Context, base64Body string, s3Based bool) error { + return f(ctx, base64Body, s3Based) +} + +// SNSPublisherFunc returns a message publisher that sends messages to a SNS +// topic through the given SNS client. +func SNSPublisherFunc(topicARN string, snsClient *sns.Client) messagePublisherFunc { + return func(ctx context.Context, base64Body string, s3Based bool) error { + var messageAttributes map[string]snstypes.MessageAttributeValue + if s3Based { + messageAttributes = map[string]snstypes.MessageAttributeValue{ + payloadTypeAttr: { + DataType: aws.String("String"), + StringValue: aws.String(payloadTypeS3Based), + }, + } + } else { + messageAttributes = map[string]snstypes.MessageAttributeValue{ + payloadTypeAttr: { + DataType: aws.String("String"), + StringValue: aws.String(payloadTypeRawProtoEvent), + }, + } + } + + _, err := snsClient.Publish(ctx, &sns.PublishInput{ + TopicArn: &topicARN, + Message: &base64Body, + MessageAttributes: messageAttributes, + }) + return trace.Wrap(err) + } +} + +// SQSPublisherFunc returns a message publisher that sends messages to a SQS +// queue through the given SQS client. +func SQSPublisherFunc(queueURL string, sqsClient *sqs.Client) messagePublisherFunc { + return func(ctx context.Context, base64Body string, s3Based bool) error { + var messageAttributes map[string]sqstypes.MessageAttributeValue + if s3Based { + messageAttributes = map[string]sqstypes.MessageAttributeValue{ + payloadTypeAttr: { + DataType: aws.String("String"), + StringValue: aws.String(payloadTypeS3Based), + }, + } + } else { + messageAttributes = map[string]sqstypes.MessageAttributeValue{ + payloadTypeAttr: { + DataType: aws.String("String"), + StringValue: aws.String(payloadTypeRawProtoEvent), + }, + } + } + + _, err := sqsClient.SendMessage(ctx, &sqs.SendMessageInput{ + QueueUrl: &queueURL, + MessageBody: &base64Body, + MessageAttributes: messageAttributes, + }) + return trace.Wrap(err) + } } type s3uploader interface { - Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*manager.Uploader)) (*manager.UploadOutput, error) + Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) } type PublisherConfig struct { - TopicARN string - SNSPublisher snsPublisher - Uploader s3uploader - PayloadBucket string - PayloadPrefix string + MessagePublisher messagePublisher + Uploader s3uploader + PayloadBucket string + PayloadPrefix string } // NewPublisher returns new instance of publisher. @@ -90,17 +164,37 @@ func NewPublisher(cfg PublisherConfig) *publisher { // newPublisherFromAthenaConfig returns new instance of publisher from athena // config. func newPublisherFromAthenaConfig(cfg Config) *publisher { - r := retry.NewStandard(func(so *retry.StandardOptions) { + r := awsretry.NewStandard(func(so *awsretry.StandardOptions) { so.MaxAttempts = 20 so.MaxBackoff = 1 * time.Minute + // failure to do an API call likely means that we've just lost data, so + // let's just have the server bounce us back repeatedly rather than give + // up in the client + so.RateLimiter = awsratelimit.None }) - return NewPublisher(PublisherConfig{ - TopicARN: cfg.TopicARN, - SNSPublisher: sns.NewFromConfig(*cfg.PublisherConsumerAWSConfig, func(o *sns.Options) { + hc := awshttp.NewBuildableClient().WithTransportOptions(func(t *http.Transport) { + // aggressively reuse connections for the sake of avoiding TLS + // handshakes (the default MaxIdleConnsPerHost is a pitiful 2) + t.MaxIdleConns = defaults.HTTPMaxIdleConns + t.MaxIdleConnsPerHost = defaults.HTTPMaxIdleConnsPerHost + }) + var messagePublisher messagePublisherFunc + if cfg.TopicARN == topicARNBypass { + messagePublisher = SQSPublisherFunc(cfg.QueueURL, sqs.NewFromConfig(*cfg.PublisherConsumerAWSConfig, func(o *sqs.Options) { + o.Retryer = r + o.HTTPClient = hc + })) + } else { + messagePublisher = SNSPublisherFunc(cfg.TopicARN, sns.NewFromConfig(*cfg.PublisherConsumerAWSConfig, func(o *sns.Options) { o.Retryer = r - }), + o.HTTPClient = hc + })) + } + + return NewPublisher(PublisherConfig{ + MessagePublisher: messagePublisher, // TODO(tobiaszheller): consider reworking lib/observability to work also on s3 sdk-v2. - Uploader: manager.NewUploader(s3.NewFromConfig(*cfg.PublisherConsumerAWSConfig)), + Uploader: s3manager.NewUploader(s3.NewFromConfig(*cfg.PublisherConsumerAWSConfig)), PayloadBucket: cfg.largeEventsBucket, PayloadPrefix: cfg.largeEventsPrefix, }) @@ -150,57 +244,39 @@ func (p *publisher) EmitAuditEvent(ctx context.Context, in apievents.AuditEvent) return trace.Wrap(err) } - b64Encoded := base64.StdEncoding.EncodeToString(marshaledProto) - if len(b64Encoded) > maxSNSMessageSize { - if len(b64Encoded) > maxS3BasedSize { - return trace.BadParameter("message too large to publish, size %d", len(b64Encoded)) + base64Len := base64.StdEncoding.EncodedLen(len(marshaledProto)) + if base64Len > maxDirectMessageSize { + if base64Len > maxS3BasedSize { + return trace.BadParameter("message too large to publish, size %d", base64Len) } return trace.Wrap(p.emitViaS3(ctx, in.GetID(), marshaledProto)) } - return trace.Wrap(p.emitViaSNS(ctx, in.GetID(), b64Encoded)) + base64Body := base64.StdEncoding.EncodeToString(marshaledProto) + const s3BasedFalse = false + return trace.Wrap(p.MessagePublisher.Publish(ctx, base64Body, s3BasedFalse)) } func (p *publisher) emitViaS3(ctx context.Context, uid string, marshaledEvent []byte) error { - path := filepath.Join(p.PayloadPrefix, uid) + path := path.Join(p.PayloadPrefix, uid) out, err := p.Uploader.Upload(ctx, &s3.PutObjectInput{ - Bucket: aws.String(p.PayloadBucket), - Key: aws.String(path), + Bucket: &p.PayloadBucket, + Key: &path, Body: bytes.NewBuffer(marshaledEvent), }) if err != nil { return trace.Wrap(err) } - var versionID string - if out.VersionID != nil { - versionID = *out.VersionID - } msg := &apievents.AthenaS3EventPayload{ Path: path, - VersionId: versionID, + VersionId: aws.ToString(out.VersionID), } buf, err := msg.Marshal() if err != nil { return trace.Wrap(err) } - _, err = p.SNSPublisher.Publish(ctx, &sns.PublishInput{ - TopicArn: aws.String(p.TopicARN), - Message: aws.String(base64.StdEncoding.EncodeToString(buf)), - MessageAttributes: map[string]snsTypes.MessageAttributeValue{ - payloadTypeAttr: {DataType: aws.String("String"), StringValue: aws.String(payloadTypeS3Based)}, - }, - }) - return trace.Wrap(err) -} - -func (p *publisher) emitViaSNS(ctx context.Context, uid string, b64Encoded string) error { - _, err := p.SNSPublisher.Publish(ctx, &sns.PublishInput{ - TopicArn: aws.String(p.TopicARN), - Message: aws.String(b64Encoded), - MessageAttributes: map[string]snsTypes.MessageAttributeValue{ - payloadTypeAttr: {DataType: aws.String("String"), StringValue: aws.String(payloadTypeRawProtoEvent)}, - }, - }) - return trace.Wrap(err) + base64Body := base64.StdEncoding.EncodeToString(buf) + const s3BasedTrue = true + return trace.Wrap(p.MessagePublisher.Publish(ctx, base64Body, s3BasedTrue)) } diff --git a/lib/events/athena/publisher_test.go b/lib/events/athena/publisher_test.go index 83f569249721a..5348c38515b05 100644 --- a/lib/events/athena/publisher_test.go +++ b/lib/events/athena/publisher_test.go @@ -36,7 +36,7 @@ import ( func init() { // Override maxS3BasedSize so we don't have to allocate 2GiB to test it. // Do this in init to avoid any race. - maxS3BasedSize = maxSNSMessageSize * 4 + maxS3BasedSize = maxDirectMessageSize * 4 } // TODO(tobiaszheller): Those UT just cover basic stuff. When we will have consumer @@ -61,7 +61,7 @@ func Test_EmitAuditEvent(t *testing.T) { }, wantCheck: func(t *testing.T, out []fakeQueueMessage) { require.Len(t, out, 1) - require.Contains(t, *out[0].attributes[payloadTypeAttr].StringValue, payloadTypeRawProtoEvent) + require.False(t, out[0].s3Based) }, }, { @@ -77,7 +77,7 @@ func Test_EmitAuditEvent(t *testing.T) { }, wantCheck: func(t *testing.T, out []fakeQueueMessage) { require.Len(t, out, 1) - require.Contains(t, *out[0].attributes[payloadTypeAttr].StringValue, payloadTypeRawProtoEvent) + require.False(t, out[0].s3Based) }, }, { @@ -86,13 +86,13 @@ func Test_EmitAuditEvent(t *testing.T) { Metadata: apievents.Metadata{ ID: uuid.NewString(), Time: time.Now().UTC(), - Code: strings.Repeat("d", 2*maxSNSMessageSize), + Code: strings.Repeat("d", 2*maxDirectMessageSize), }, }, uploader: mockUploader{}, wantCheck: func(t *testing.T, out []fakeQueueMessage) { require.Len(t, out, 1) - require.Contains(t, *out[0].attributes[payloadTypeAttr].StringValue, payloadTypeS3Based) + require.True(t, out[0].s3Based) }, }, { @@ -119,7 +119,7 @@ func Test_EmitAuditEvent(t *testing.T) { uploader: mockUploader{}, wantCheck: func(t *testing.T, out []fakeQueueMessage) { require.Len(t, out, 1) - require.Contains(t, *out[0].attributes[payloadTypeAttr].StringValue, payloadTypeS3Based) + require.True(t, out[0].s3Based) }, }, } @@ -128,8 +128,8 @@ func Test_EmitAuditEvent(t *testing.T) { fq := newFakeQueue() p := &publisher{ PublisherConfig: PublisherConfig{ - SNSPublisher: fq, - Uploader: tt.uploader, + MessagePublisher: fq, + Uploader: tt.uploader, }, } err := p.EmitAuditEvent(context.Background(), tt.in) diff --git a/lib/events/athena/test.go b/lib/events/athena/test.go index 5580afccd89f0..a93b9ff86f6ca 100644 --- a/lib/events/athena/test.go +++ b/lib/events/athena/test.go @@ -169,6 +169,13 @@ func SetupAthenaContext(t *testing.T, ctx context.Context, cfg AthenaContextConf if ok, _ := strconv.ParseBool(testEnabled); !ok { t.Skip("Skipping AWS-dependent test suite.") } + var bypassSNS bool + if s := os.Getenv(teleport.AWSRunTests + "_ATHENA_BYPASS_SNS"); s != "" { + if ok, _ := strconv.ParseBool(s); ok { + t.Log("bypassing SNS for Athena audit log events") + bypassSNS = true + } + } testID := fmt.Sprintf("auditlogs-integrationtests-%v", uuid.New().String()) @@ -200,12 +207,16 @@ func SetupAthenaContext(t *testing.T, ctx context.Context, cfg AthenaContextConf region = "eu-central-1" } + topicARN := infraOut.TopicARN + if bypassSNS { + topicARN = topicARNBypass + } log, err := New(ctx, Config{ Region: region, Clock: clock, Database: ac.Database, TableName: ac.TableName, - TopicARN: infraOut.TopicARN, + TopicARN: topicARN, QueueURL: infraOut.QueueURL, LocationS3: ac.s3eventsLocation, QueryResultsS3: ac.S3ResultsLocation,