From b88a8d1f793e192d3dbd9d803f65ce09c8599a5f Mon Sep 17 00:00:00 2001
From: Edoardo Spadolini
Date: Sun, 2 Jun 2024 11:51:01 +0200
Subject: [PATCH 1/2] Allow athenaevents to bypass SNS

---
 examples/dynamoathenamigration/migration.go |   5 +-
 lib/events/athena/athena.go                 |  13 +-
 lib/events/athena/athena_test.go            |   6 +-
 lib/events/athena/fakequeue_test.go         |  56 +++---
 lib/events/athena/publisher.go              | 198 ++++++++++++++------
 lib/events/athena/publisher_test.go         |  16 +-
 lib/events/athena/test.go                   |  13 +-
 7 files changed, 201 insertions(+), 106 deletions(-)

diff --git a/examples/dynamoathenamigration/migration.go b/examples/dynamoathenamigration/migration.go
index 8450798bd82af..af9d39ee5585a 100644
--- a/examples/dynamoathenamigration/migration.go
+++ b/examples/dynamoathenamigration/migration.go
@@ -165,15 +165,14 @@ func newMigrateTask(ctx context.Context, cfg Config, awsCfg aws.Config) (*task,
 		dynamoClient: dynamodb.NewFromConfig(awsCfg),
 		s3Downloader: manager.NewDownloader(s3Client),
 		eventsEmitter: athena.NewPublisher(athena.PublisherConfig{
-			TopicARN: cfg.TopicARN,
-			SNSPublisher: sns.NewFromConfig(awsCfg, func(o *sns.Options) {
+			MessagePublisher: athena.SNSPublisherFunc(cfg.TopicARN, sns.NewFromConfig(awsCfg, func(o *sns.Options) {
 				o.Retryer = retry.NewStandard(func(so *retry.StandardOptions) {
 					so.MaxAttempts = 30
 					so.MaxBackoff = 1 * time.Minute
 					// Use bigger rate limit to handle default sdk throttling: https://github.com/aws/aws-sdk-go-v2/issues/1665
 					so.RateLimiter = ratelimit.NewTokenRateLimit(1000000)
 				})
-			}),
+			})),
 			Uploader:      manager.NewUploader(s3Client),
 			PayloadBucket: cfg.LargePayloadBucket,
 			PayloadPrefix: cfg.LargePayloadPrefix,
diff --git a/lib/events/athena/athena.go b/lib/events/athena/athena.go
index 7c1231d69622e..984f6d5d67b97 100644
--- a/lib/events/athena/athena.go
+++ b/lib/events/athena/athena.go
@@ -49,6 +49,11 @@ const (
 	defaultBatchItems = 20000
 	// defaultBatchInterval defines default batch interval.
 	defaultBatchInterval = 1 * time.Minute
+
+	// topicARNBypass is a magic value for TopicARN that signifies that the
+	// Athena audit log should send messages directly to SQS instead of going
+	// through an SNS topic.
+	topicARNBypass = "bypass"
 )

 // Config structure represents Athena configuration.
@@ -59,7 +64,9 @@ type Config struct {

 	// Publisher settings.

-	// TopicARN where to emit events in SNS (required).
+	// TopicARN where to emit events in SNS (required). If TopicARN is "bypass"
+	// (i.e. [topicARNBypass]) then events are emitted directly to the SQS
+	// queue reachable at QueueURL.
 	TopicARN string
 	// LargeEventsS3 is location on S3 where temporary large events (>256KB)
 	// are stored before converting it to Parquet and moving to long term
@@ -103,7 +110,9 @@ type Config struct {

 	// Batcher settings.

-	// QueueURL is URL of SQS, which is set as subscriber to SNS topic (required).
+	// QueueURL is the URL of the SQS queue: it is subscribed to the SNS topic
+	// if we're emitting to SNS, or used directly to send messages if we're
+	// bypassing SNS (required).
 	QueueURL string
 	// BatchMaxItems defines how many items can be stored in single Parquet
 	// batch (optional).
diff --git a/lib/events/athena/athena_test.go b/lib/events/athena/athena_test.go
index 533065724b507..fb908c5aeb3de 100644
--- a/lib/events/athena/athena_test.go
+++ b/lib/events/athena/athena_test.go
@@ -331,7 +331,7 @@ func TestPublisherConsumer(t *testing.T) {
 				ID:   uuid.NewString(),
 				Time: time.Now().UTC(),
 				Type: events.AppCreateEvent,
-				Code: strings.Repeat("d", 2*maxSNSMessageSize),
+				Code: strings.Repeat("d", 2*maxDirectMessageSize),
 			},
 			AppMetadata: apievents.AppMetadata{
 				AppName: "app-large",
@@ -414,8 +414,8 @@ func TestPublisherConsumer(t *testing.T) {
 	fq := newFakeQueue()
 	p := &publisher{
 		PublisherConfig: PublisherConfig{
-			SNSPublisher: fq,
-			Uploader:     fS3,
+			MessagePublisher: fq,
+			Uploader:         fS3,
 		},
 	}
 	cfg := validCollectCfgForTests(t)
diff --git a/lib/events/athena/fakequeue_test.go b/lib/events/athena/fakequeue_test.go
index f594162071015..eb982adba3d5c 100644
--- a/lib/events/athena/fakequeue_test.go
+++ b/lib/events/athena/fakequeue_test.go
@@ -19,10 +19,8 @@ import (
 	"sync"

 	"github.com/aws/aws-sdk-go-v2/aws"
-	"github.com/aws/aws-sdk-go-v2/service/sns"
-	snsTypes "github.com/aws/aws-sdk-go-v2/service/sns/types"
 	"github.com/aws/aws-sdk-go-v2/service/sqs"
-	sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
+	sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
 	"github.com/google/uuid"
 )

@@ -38,27 +36,27 @@ type fakeQueue struct {
 }

 type fakeQueueMessage struct {
-	payload    string
-	attributes map[string]snsTypes.MessageAttributeValue
+	payload string
+	s3Based bool
 }

 func newFakeQueue() *fakeQueue {
 	return &fakeQueue{}
 }

-func (f *fakeQueue) Publish(ctx context.Context, params *sns.PublishInput, optFns ...func(*sns.Options)) (*sns.PublishOutput, error) {
+func (f *fakeQueue) Publish(ctx context.Context, base64Body string, s3Based bool) error {
 	f.mu.Lock()
 	defer f.mu.Unlock()
 	if len(f.publishErrors) > 0 {
 		err := f.publishErrors[0]
 		f.publishErrors = f.publishErrors[1:]
-		return nil, err
+		return err
 	}
 	f.msgs = append(f.msgs, fakeQueueMessage{
-		payload:    *params.Message,
-		attributes: params.MessageAttributes,
+		payload: base64Body,
+		s3Based: s3Based,
 	})
-	return nil, nil
+	return nil
 }

 func (f *fakeQueue) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
@@ -66,11 +64,27 @@ func (f *fakeQueue) ReceiveMessa
 	if len(msgs) == 0 {
 		return &sqs.ReceiveMessageOutput{}, nil
 	}
-	out := make([]sqsTypes.Message, 0, 10)
+	out := make([]sqstypes.Message, 0, len(msgs))
 	for _, msg := range msgs {
-		out = append(out, sqsTypes.Message{
-			Body:              aws.String(msg.payload),
-			MessageAttributes: snsToSqsAttributes(msg.attributes),
+		var messageAttributes map[string]sqstypes.MessageAttributeValue
+		if msg.s3Based {
+			messageAttributes = map[string]sqstypes.MessageAttributeValue{
+				payloadTypeAttr: {
+					DataType:    aws.String("String"),
+					StringValue: aws.String(payloadTypeS3Based),
+				},
+			}
+		} else {
+			messageAttributes = map[string]sqstypes.MessageAttributeValue{
+				payloadTypeAttr: {
+					DataType:    aws.String("String"),
+					StringValue: aws.String(payloadTypeRawProtoEvent),
+				},
+			}
+		}
+		out = append(out, sqstypes.Message{
+			Body:              &msg.payload,
+			MessageAttributes: messageAttributes,
 			ReceiptHandle:     aws.String(uuid.NewString()),
 		})
 	}
@@ -79,20 +93,6 @@
 	}, nil
 }

-func snsToSqsAttributes(in map[string]snsTypes.MessageAttributeValue) map[string]sqsTypes.MessageAttributeValue {
-	if in == nil {
-		return nil
-	}
-	out := map[string]sqsTypes.MessageAttributeValue{}
-	for k, v := range in {
-		out[k] = sqsTypes.MessageAttributeValue{
-			DataType:    v.DataType,
-			StringValue: v.StringValue,
-		}
-	}
-	return out
-}
-
 func (f *fakeQueue) dequeue() []fakeQueueMessage {
 	f.mu.Lock()
 	defer f.mu.Unlock()
diff --git a/lib/events/athena/publisher.go b/lib/events/athena/publisher.go
index 333bd5f0595c6..ef53e464476bd 100644
--- a/lib/events/athena/publisher.go
+++ b/lib/events/athena/publisher.go
@@ -18,19 +18,25 @@ import (
 	"bytes"
 	"context"
 	"encoding/base64"
-	"path/filepath"
+	"net/http"
+	"path"
 	"time"

 	"github.com/aws/aws-sdk-go-v2/aws"
-	"github.com/aws/aws-sdk-go-v2/aws/retry"
-	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+	awsratelimit "github.com/aws/aws-sdk-go-v2/aws/ratelimit"
+	awsretry "github.com/aws/aws-sdk-go-v2/aws/retry"
+	awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
+	s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
 	"github.com/aws/aws-sdk-go-v2/service/s3"
 	"github.com/aws/aws-sdk-go-v2/service/sns"
-	snsTypes "github.com/aws/aws-sdk-go-v2/service/sns/types"
+	snstypes "github.com/aws/aws-sdk-go-v2/service/sns/types"
+	"github.com/aws/aws-sdk-go-v2/service/sqs"
+	sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
 	"github.com/google/uuid"
 	"github.com/gravitational/trace"

 	apievents "github.com/gravitational/teleport/api/types/events"
+	"github.com/gravitational/teleport/lib/defaults"
 	"github.com/gravitational/teleport/lib/events"
 	"github.com/gravitational/teleport/lib/internal/context121"
 )
@@ -40,19 +46,18 @@ const (
 	payloadTypeRawProtoEvent = "raw_proto_event"
 	payloadTypeS3Based       = "s3_event"

-	// maxSNSMessageSize defines maximum size of SNS message. AWS allows 256KB
-	// however it counts also headers. We round it to 250KB, just to be sure.
-	maxSNSMessageSize = 250 * 1024
+	// maxDirectMessageSize defines maximum size of SNS/SQS message. AWS allows
+	// 256KB however it counts also headers. We round it to 250KB, just to be
+	// sure.
+	maxDirectMessageSize = 250 * 1024
 )

-var (
-	// maxS3BasedSize defines some resonable threshold for S3 based messages
-	// (almost 2GiB but fits in an int).
-	//
-	// It's a var instead of const so tests can override it instead of casually
-	// allocating 2GiB.
-	maxS3BasedSize = 2*1024*1024*1024 - 1
-)
+// maxS3BasedSize defines some reasonable threshold for S3 based messages
+// (almost 2GiB but fits in an int).
+//
+// It's a var instead of const so tests can override it instead of casually
+// allocating 2GiB.
+var maxS3BasedSize = 2*1024*1024*1024 - 1

 // publisher is a SNS based events publisher.
 // It publishes proto events directly to SNS topic, or use S3 bucket
@@ -61,20 +66,89 @@ type publisher struct {
 	PublisherConfig
 }

-type snsPublisher interface {
-	Publish(ctx context.Context, params *sns.PublishInput, optFns ...func(*sns.Options)) (*sns.PublishOutput, error)
+type messagePublisher interface {
+	// Publish sends a message with a given body to a notification topic or a
+	// queue (or something similar), with added metadata to signify whether or
+	// not the message is only a reference to an S3 object or a full message.
+	Publish(ctx context.Context, base64Body string, s3Based bool) error
+}
+
+type messagePublisherFunc func(ctx context.Context, base64Body string, s3Based bool) error
+
+// Publish implements [messagePublisher].
+func (f messagePublisherFunc) Publish(ctx context.Context, base64Body string, s3Based bool) error {
+	return f(ctx, base64Body, s3Based)
+}
+
+// SNSPublisherFunc returns a message publisher that sends messages to an SNS
+// topic through the given SNS client.
+func SNSPublisherFunc(topicARN string, snsClient *sns.Client) messagePublisherFunc {
+	return func(ctx context.Context, base64Body string, s3Based bool) error {
+		var messageAttributes map[string]snstypes.MessageAttributeValue
+		if s3Based {
+			messageAttributes = map[string]snstypes.MessageAttributeValue{
+				payloadTypeAttr: {
+					DataType:    aws.String("String"),
+					StringValue: aws.String(payloadTypeS3Based),
+				},
+			}
+		} else {
+			messageAttributes = map[string]snstypes.MessageAttributeValue{
+				payloadTypeAttr: {
+					DataType:    aws.String("String"),
+					StringValue: aws.String(payloadTypeRawProtoEvent),
+				},
+			}
+		}
+
+		_, err := snsClient.Publish(ctx, &sns.PublishInput{
+			TopicArn:          &topicARN,
+			Message:           &base64Body,
+			MessageAttributes: messageAttributes,
+		})
+		return trace.Wrap(err)
+	}
+}
+
+// SQSPublisherFunc returns a message publisher that sends messages to an SQS
+// queue through the given SQS client.
+func SQSPublisherFunc(queueURL string, sqsClient *sqs.Client) messagePublisherFunc {
+	return func(ctx context.Context, base64Body string, s3Based bool) error {
+		var messageAttributes map[string]sqstypes.MessageAttributeValue
+		if s3Based {
+			messageAttributes = map[string]sqstypes.MessageAttributeValue{
+				payloadTypeAttr: {
+					DataType:    aws.String("String"),
+					StringValue: aws.String(payloadTypeS3Based),
+				},
+			}
+		} else {
+			messageAttributes = map[string]sqstypes.MessageAttributeValue{
+				payloadTypeAttr: {
+					DataType:    aws.String("String"),
+					StringValue: aws.String(payloadTypeRawProtoEvent),
+				},
+			}
+		}
+
+		_, err := sqsClient.SendMessage(ctx, &sqs.SendMessageInput{
+			QueueUrl:          &queueURL,
+			MessageBody:       &base64Body,
+			MessageAttributes: messageAttributes,
+		})
+		return trace.Wrap(err)
+	}
 }

 type s3uploader interface {
-	Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*manager.Uploader)) (*manager.UploadOutput, error)
+	Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
 }

 type PublisherConfig struct {
-	TopicARN      string
-	SNSPublisher  snsPublisher
-	Uploader      s3uploader
-	PayloadBucket string
-	PayloadPrefix string
+	MessagePublisher messagePublisher
+	Uploader         s3uploader
+	PayloadBucket    string
+	PayloadPrefix    string
 }

 // NewPublisher returns new instance of publisher.
@@ -87,17 +161,37 @@
 // newPublisherFromAthenaConfig returns new instance of publisher from athena
 // config.
 func newPublisherFromAthenaConfig(cfg Config) *publisher {
-	r := retry.NewStandard(func(so *retry.StandardOptions) {
+	r := awsretry.NewStandard(func(so *awsretry.StandardOptions) {
 		so.MaxAttempts = 20
 		so.MaxBackoff = 1 * time.Minute
+		// failure to do an API call likely means that we've just lost data, so
+		// let's just have the server bounce us back repeatedly rather than give
+		// up in the client
+		so.RateLimiter = awsratelimit.None
 	})
-	return NewPublisher(PublisherConfig{
-		TopicARN: cfg.TopicARN,
-		SNSPublisher: sns.NewFromConfig(*cfg.PublisherConsumerAWSConfig, func(o *sns.Options) {
+	hc := awshttp.NewBuildableClient().WithTransportOptions(func(t *http.Transport) {
+		// aggressively reuse connections for the sake of avoiding TLS
+		// handshakes (the default MaxIdleConnsPerHost is a pitiful 2)
+		t.MaxIdleConns = defaults.HTTPMaxIdleConns
+		t.MaxIdleConnsPerHost = defaults.HTTPMaxIdleConnsPerHost
+	})
+	var messagePublisher messagePublisherFunc
+	if cfg.TopicARN == topicARNBypass {
+		messagePublisher = SQSPublisherFunc(cfg.QueueURL, sqs.NewFromConfig(*cfg.PublisherConsumerAWSConfig, func(o *sqs.Options) {
+			o.Retryer = r
+			o.HTTPClient = hc
+		}))
+	} else {
+		messagePublisher = SNSPublisherFunc(cfg.TopicARN, sns.NewFromConfig(*cfg.PublisherConsumerAWSConfig, func(o *sns.Options) {
 			o.Retryer = r
-		}),
+			o.HTTPClient = hc
+		}))
+	}
+
+	return NewPublisher(PublisherConfig{
+		MessagePublisher: messagePublisher,
 		// TODO(tobiaszheller): consider reworking lib/observability to work also on s3 sdk-v2.
-		Uploader:      manager.NewUploader(s3.NewFromConfig(*cfg.PublisherConsumerAWSConfig)),
+		Uploader:      s3manager.NewUploader(s3.NewFromConfig(*cfg.PublisherConsumerAWSConfig)),
 		PayloadBucket: cfg.largeEventsBucket,
 		PayloadPrefix: cfg.largeEventsPrefix,
 	})
@@ -147,57 +241,39 @@ func (p *publisher) EmitAuditEvent(ctx context.Context, in apievents.AuditEvent)
 		return trace.Wrap(err)
 	}

-	b64Encoded := base64.StdEncoding.EncodeToString(marshaledProto)
-	if len(b64Encoded) > maxSNSMessageSize {
-		if len(b64Encoded) > maxS3BasedSize {
-			return trace.BadParameter("message too large to publish, size %d", len(b64Encoded))
+	base64Len := base64.StdEncoding.EncodedLen(len(marshaledProto))
+	if base64Len > maxDirectMessageSize {
+		if base64Len > maxS3BasedSize {
+			return trace.BadParameter("message too large to publish, size %d", base64Len)
 		}
 		return trace.Wrap(p.emitViaS3(ctx, in.GetID(), marshaledProto))
 	}
-	return trace.Wrap(p.emitViaSNS(ctx, in.GetID(), b64Encoded))
+	base64Body := base64.StdEncoding.EncodeToString(marshaledProto)
+	const s3BasedFalse = false
+	return trace.Wrap(p.MessagePublisher.Publish(ctx, base64Body, s3BasedFalse))
 }

 func (p *publisher) emitViaS3(ctx context.Context, uid string, marshaledEvent []byte) error {
-	path := filepath.Join(p.PayloadPrefix, uid)
+	path := path.Join(p.PayloadPrefix, uid)
 	out, err := p.Uploader.Upload(ctx, &s3.PutObjectInput{
-		Bucket: aws.String(p.PayloadBucket),
-		Key:    aws.String(path),
+		Bucket: &p.PayloadBucket,
+		Key:    &path,
 		Body:   bytes.NewBuffer(marshaledEvent),
 	})
 	if err != nil {
 		return trace.Wrap(err)
 	}
-	var versionID string
-	if out.VersionID != nil {
-		versionID = *out.VersionID
-	}
 	msg := &apievents.AthenaS3EventPayload{
 		Path:      path,
-		VersionId: versionID,
+		VersionId: aws.ToString(out.VersionID),
 	}
 	buf, err := msg.Marshal()
 	if err != nil {
 		return trace.Wrap(err)
 	}
-	_, err = p.SNSPublisher.Publish(ctx, &sns.PublishInput{
-		TopicArn: aws.String(p.TopicARN),
-		Message:  aws.String(base64.StdEncoding.EncodeToString(buf)),
-		MessageAttributes: map[string]snsTypes.MessageAttributeValue{
-			payloadTypeAttr: {DataType: aws.String("String"), StringValue: aws.String(payloadTypeS3Based)},
-		},
-	})
-	return trace.Wrap(err)
-}
-
-func (p *publisher) emitViaSNS(ctx context.Context, uid string, b64Encoded string) error {
-	_, err := p.SNSPublisher.Publish(ctx, &sns.PublishInput{
-		TopicArn: aws.String(p.TopicARN),
-		Message:  aws.String(b64Encoded),
-		MessageAttributes: map[string]snsTypes.MessageAttributeValue{
-			payloadTypeAttr: {DataType: aws.String("String"), StringValue: aws.String(payloadTypeRawProtoEvent)},
-		},
-	})
-	return trace.Wrap(err)
+	base64Body := base64.StdEncoding.EncodeToString(buf)
+	const s3BasedTrue = true
+	return trace.Wrap(p.MessagePublisher.Publish(ctx, base64Body, s3BasedTrue))
 }
diff --git a/lib/events/athena/publisher_test.go b/lib/events/athena/publisher_test.go
index 062f1f7415502..6c37ba8649589 100644
--- a/lib/events/athena/publisher_test.go
+++ b/lib/events/athena/publisher_test.go
@@ -32,7 +32,7 @@ import (
 func init() {
 	// Override maxS3BasedSize so we don't have to allocate 2GiB to test it.
 	// Do this in init to avoid any race.
-	maxS3BasedSize = maxSNSMessageSize * 4
+	maxS3BasedSize = maxDirectMessageSize * 4
 }

 // TODO(tobiaszheller): Those UT just cover basic stuff. When we will have consumer
@@ -57,7 +57,7 @@ func Test_EmitAuditEvent(t *testing.T) {
 			},
 			wantCheck: func(t *testing.T, out []fakeQueueMessage) {
 				require.Len(t, out, 1)
-				require.Contains(t, *out[0].attributes[payloadTypeAttr].StringValue, payloadTypeRawProtoEvent)
+				require.False(t, out[0].s3Based)
 			},
 		},
 		{
@@ -73,7 +73,7 @@ func Test_EmitAuditEvent(t *testing.T) {
 			},
 			wantCheck: func(t *testing.T, out []fakeQueueMessage) {
 				require.Len(t, out, 1)
-				require.Contains(t, *out[0].attributes[payloadTypeAttr].StringValue, payloadTypeRawProtoEvent)
+				require.False(t, out[0].s3Based)
 			},
 		},
 		{
@@ -82,13 +82,13 @@ func Test_EmitAuditEvent(t *testing.T) {
 			Metadata: apievents.Metadata{
 				ID:   uuid.NewString(),
 				Time: time.Now().UTC(),
-				Code: strings.Repeat("d", 2*maxSNSMessageSize),
+				Code: strings.Repeat("d", 2*maxDirectMessageSize),
 			},
 			},
 			uploader: mockUploader{},
 			wantCheck: func(t *testing.T, out []fakeQueueMessage) {
 				require.Len(t, out, 1)
-				require.Contains(t, *out[0].attributes[payloadTypeAttr].StringValue, payloadTypeS3Based)
+				require.True(t, out[0].s3Based)
 			},
 		},
 		{
@@ -115,7 +115,7 @@ func Test_EmitAuditEvent(t *testing.T) {
 			uploader: mockUploader{},
 			wantCheck: func(t *testing.T, out []fakeQueueMessage) {
 				require.Len(t, out, 1)
-				require.Contains(t, *out[0].attributes[payloadTypeAttr].StringValue, payloadTypeS3Based)
+				require.True(t, out[0].s3Based)
 			},
 		},
 	}
@@ -124,8 +124,8 @@ func Test_EmitAuditEvent(t *testing.T) {
 			fq := newFakeQueue()
 			p := &publisher{
 				PublisherConfig: PublisherConfig{
-					SNSPublisher: fq,
-					Uploader:     tt.uploader,
+					MessagePublisher: fq,
+					Uploader:         tt.uploader,
 				},
 			}
 			err := p.EmitAuditEvent(context.Background(), tt.in)
diff --git a/lib/events/athena/test.go b/lib/events/athena/test.go
index b18d5171b53fd..b0041f081d934 100644
--- a/lib/events/athena/test.go
+++ b/lib/events/athena/test.go
@@ -165,6 +165,13 @@ func SetupAthenaContext(t *testing.T, ctx context.Context, cfg AthenaContextConf
 	if ok, _ := strconv.ParseBool(testEnabled); !ok {
 		t.Skip("Skipping AWS-dependent test suite.")
 	}
+	var bypassSNS bool
+	if s := os.Getenv(teleport.AWSRunTests + "_ATHENA_BYPASS_SNS"); s != "" {
+		if ok, _ := strconv.ParseBool(s); ok {
+			t.Log("bypassing SNS for Athena audit log events")
+			bypassSNS = true
+		}
+	}

 	testID := fmt.Sprintf("auditlogs-integrationtests-%v", uuid.New().String())
@@ -196,12 +203,16 @@ func SetupAthenaContext(t *testing.T, ctx context.Context, cfg AthenaContextConf
 		region = "eu-central-1"
 	}

+	topicARN := infraOut.TopicARN
+	if bypassSNS {
+		topicARN = topicARNBypass
+	}
 	log, err := New(ctx, Config{
 		Region:         region,
 		Clock:          clock,
 		Database:       ac.Database,
 		TableName:      ac.TableName,
-		TopicARN:       infraOut.TopicARN,
+		TopicARN:       topicARN,
 		QueueURL:       infraOut.QueueURL,
 		LocationS3:     ac.s3eventsLocation,
 		QueryResultsS3: ac.S3ResultsLocation,

From 6d55206ababc373f9e07ba4d9175e4618c192fba Mon Sep 17 00:00:00 2001
From: Edoardo Spadolini
Date: Tue, 4 Jun 2024 21:09:07 +0200
Subject: [PATCH 2/2] Double up the athena tests that hit AWS

---
 lib/events/athena/integration_test.go | 55 +++++++++++++++++++++++++--
 lib/events/athena/test.go             | 10 +----
 2 files changed, 53 insertions(+), 12 deletions(-)

diff --git a/lib/events/athena/integration_test.go b/lib/events/athena/integration_test.go
index cd1aed257ffca..9b6c6fa46d4e2 100644
--- a/lib/events/athena/integration_test.go
+++ b/lib/events/athena/integration_test.go
@@ -33,9 +33,20 @@ import (
 )

 func TestIntegrationAthenaSearchSessionEventsBySessionID(t *testing.T) {
+	t.Run("sns", func(t *testing.T) {
+		const bypassSNSFalse = false
+		testIntegrationAthenaSearchSessionEventsBySessionID(t, bypassSNSFalse)
+	})
+	t.Run("sqs", func(t *testing.T) {
+		const bypassSNSTrue = true
+		testIntegrationAthenaSearchSessionEventsBySessionID(t, bypassSNSTrue)
+	})
+}
+
+func testIntegrationAthenaSearchSessionEventsBySessionID(t *testing.T, bypassSNS bool) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
-	ac := SetupAthenaContext(t, ctx, AthenaContextConfig{})
+	ac := SetupAthenaContext(t, ctx, AthenaContextConfig{BypassSNS: bypassSNS})
 	auditLogger := &EventuallyConsistentAuditLogger{
 		Inner: ac.log,
 		// Additional 5s is used to compensate for uploading parquet on s3.
@@ -51,9 +62,20 @@ func TestIntegrationAthenaSearchSessionEventsBySessionID(t *testing.T) {
 }

 func TestIntegrationAthenaSessionEventsCRUD(t *testing.T) {
+	t.Run("sns", func(t *testing.T) {
+		const bypassSNSFalse = false
+		testIntegrationAthenaSessionEventsCRUD(t, bypassSNSFalse)
+	})
+	t.Run("sqs", func(t *testing.T) {
+		const bypassSNSTrue = true
+		testIntegrationAthenaSessionEventsCRUD(t, bypassSNSTrue)
+	})
+}
+
+func testIntegrationAthenaSessionEventsCRUD(t *testing.T, bypassSNS bool) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
-	ac := SetupAthenaContext(t, ctx, AthenaContextConfig{})
+	ac := SetupAthenaContext(t, ctx, AthenaContextConfig{BypassSNS: bypassSNS})
 	auditLogger := &EventuallyConsistentAuditLogger{
 		Inner: ac.log,
 		// Additional 5s is used to compensate for uploading parquet on s3.
@@ -68,9 +90,20 @@ func TestIntegrationAthenaSessionEventsCRUD(t *testing.T) {
 }

 func TestIntegrationAthenaEventPagination(t *testing.T) {
+	t.Run("sns", func(t *testing.T) {
+		const bypassSNSFalse = false
+		testIntegrationAthenaEventPagination(t, bypassSNSFalse)
+	})
+	t.Run("sqs", func(t *testing.T) {
+		const bypassSNSTrue = true
+		testIntegrationAthenaEventPagination(t, bypassSNSTrue)
+	})
+}
+
+func testIntegrationAthenaEventPagination(t *testing.T, bypassSNS bool) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
-	ac := SetupAthenaContext(t, ctx, AthenaContextConfig{})
+	ac := SetupAthenaContext(t, ctx, AthenaContextConfig{BypassSNS: bypassSNS})
 	auditLogger := &EventuallyConsistentAuditLogger{
 		Inner: ac.log,
 		// Additional 5s is used to compensate for uploading parquet on s3.
@@ -85,10 +118,24 @@ func TestIntegrationAthenaEventPagination(t *testing.T) {
 }

 func TestIntegrationAthenaLargeEvents(t *testing.T) {
+	t.Run("sns", func(t *testing.T) {
+		const bypassSNSFalse = false
+		testIntegrationAthenaLargeEvents(t, bypassSNSFalse)
+	})
+	t.Run("sqs", func(t *testing.T) {
+		const bypassSNSTrue = true
+		testIntegrationAthenaLargeEvents(t, bypassSNSTrue)
+	})
+}
+
+func testIntegrationAthenaLargeEvents(t *testing.T, bypassSNS bool) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
-	ac := SetupAthenaContext(t, ctx, AthenaContextConfig{MaxBatchSize: 1})
+	ac := SetupAthenaContext(t, ctx, AthenaContextConfig{
+		MaxBatchSize: 1,
+		BypassSNS:    bypassSNS,
+	})
 	in := &apievents.SessionStart{
 		Metadata: apievents.Metadata{
 			Index: 2,
diff --git a/lib/events/athena/test.go b/lib/events/athena/test.go
index b0041f081d934..ea2eec9e455fa 100644
--- a/lib/events/athena/test.go
+++ b/lib/events/athena/test.go
@@ -102,6 +102,7 @@ func (a *AthenaContext) GetLog() *Log {
 // AthenaContextConfig is optional config to override defaults in athena context.
 type AthenaContextConfig struct {
 	MaxBatchSize int
+	BypassSNS    bool
 }

 type InfraOutputs struct {
@@ -165,13 +166,6 @@ func SetupAthenaContext(t *testing.T, ctx context.Context, cfg AthenaContextConf
 	if ok, _ := strconv.ParseBool(testEnabled); !ok {
 		t.Skip("Skipping AWS-dependent test suite.")
 	}
-	var bypassSNS bool
-	if s := os.Getenv(teleport.AWSRunTests + "_ATHENA_BYPASS_SNS"); s != "" {
-		if ok, _ := strconv.ParseBool(s); ok {
-			t.Log("bypassing SNS for Athena audit log events")
-			bypassSNS = true
-		}
-	}

 	testID := fmt.Sprintf("auditlogs-integrationtests-%v", uuid.New().String())

@@ -204,7 +198,7 @@ func SetupAthenaContext(t *testing.T, ctx context.Context, cfg AthenaContextConf
 	}

 	topicARN := infraOut.TopicARN
-	if bypassSNS {
+	if cfg.BypassSNS {
 		topicARN = topicARNBypass
 	}
 	log, err := New(ctx, Config{
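
Usage sketch (not part of either patch): the snippet below shows how a caller might opt into the SQS bypass introduced above, by setting TopicARN to the magic "bypass" value so that newPublisherFromAthenaConfig wires up the SQS-backed publisher. It assumes the athena package is imported from github.com/gravitational/teleport/lib/events/athena (the path touched by the patches); every concrete value (region, database, table, queue URL, S3 locations) is hypothetical, and the field set shown may not be exhaustive for a real deployment.

package main

import (
	"context"
	"log"

	"github.com/gravitational/teleport/lib/events/athena"
)

func main() {
	ctx := context.Background()

	// TopicARN set to "bypass" (i.e. topicARNBypass) makes the publisher send
	// messages directly to the SQS queue at QueueURL, instead of publishing to
	// an SNS topic that the queue is subscribed to.
	auditLog, err := athena.New(ctx, athena.Config{
		Region:         "eu-central-1",                                                      // hypothetical region
		Database:       "example_db",                                                        // hypothetical Glue database
		TableName:      "example_events",                                                    // hypothetical Athena table
		TopicARN:       "bypass",                                                            // skip SNS, emit straight to SQS
		QueueURL:       "https://sqs.eu-central-1.amazonaws.com/111111111111/example-queue", // hypothetical queue
		LargeEventsS3:  "s3://example-transient-bucket/large_payloads",                      // hypothetical bucket
		LocationS3:     "s3://example-long-term-bucket/events",                              // hypothetical bucket
		QueryResultsS3: "s3://example-transient-bucket/query_results",                       // hypothetical bucket
	})
	if err != nil {
		log.Fatal(err)
	}

	// auditLog can now be used as an audit logger; emitted events go to the
	// SQS queue and are batched into Parquet by the consumer, the same as on
	// the SNS path.
	_ = auditLog
}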