diff --git a/contrib/aws/internal/eventbridge/eventbridge.go b/contrib/aws/internal/eventbridge/eventbridge.go index e82f4ecbb1..4ef556d04b 100644 --- a/contrib/aws/internal/eventbridge/eventbridge.go +++ b/contrib/aws/internal/eventbridge/eventbridge.go @@ -7,13 +7,13 @@ package eventbridge import ( "encoding/json" + "fmt" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/eventbridge" "github.com/aws/aws-sdk-go-v2/service/eventbridge/types" "github.com/aws/smithy-go/middleware" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" - "strconv" "time" ) @@ -38,16 +38,7 @@ func handlePutEvents(span tracer.Span, in middleware.InitializeInput) { return } - for i := range params.Entries { - injectTraceContext(span, ¶ms.Entries[i]) - } -} - -func injectTraceContext(span tracer.Span, entryPtr *types.PutEventsRequestEntry) { - if entryPtr == nil { - return - } - + // Create trace context carrier := tracer.TextMapCarrier{} err := tracer.Inject(span.Context(), carrier) if err != nil { @@ -55,46 +46,63 @@ func injectTraceContext(span tracer.Span, entryPtr *types.PutEventsRequestEntry) return } - // Add start time and resource name + carrierJSON, err := json.Marshal(carrier) + if err != nil { + log.Debug("Unable to marshal trace context: %s", err) + return + } + + // Prepare the reused trace context string startTimeMillis := time.Now().UnixMilli() - carrier[startTimeKey] = strconv.FormatInt(startTimeMillis, 10) + reusedTraceContext := fmt.Sprintf(`%s,"%s":"%d"`, carrierJSON[:len(carrierJSON)-1], startTimeKey, startTimeMillis) + + for i := range params.Entries { + injectTraceContext(reusedTraceContext, ¶ms.Entries[i]) + } +} + +func injectTraceContext(baseTraceContext string, entryPtr *types.PutEventsRequestEntry) { + if entryPtr == nil { + return + } + + // Build the complete trace context + var traceContext string if entryPtr.EventBusName != nil { - carrier[resourceNameKey] = *entryPtr.EventBusName + traceContext = fmt.Sprintf(`%s,"%s":"%s"}`, baseTraceContext, resourceNameKey, *entryPtr.EventBusName) + } else { + traceContext = baseTraceContext + "}" } - var detail map[string]interface{} - if entryPtr.Detail != nil { - err = json.Unmarshal([]byte(*entryPtr.Detail), &detail) - if err != nil { - log.Debug("Unable to unmarshal event detail: %s", err) - return - } + // Get current detail string + var detail string + if entryPtr.Detail == nil || *entryPtr.Detail == "" { + detail = "{}" } else { - detail = make(map[string]interface{}) + detail = *entryPtr.Detail } - jsonBytes, err := json.Marshal(carrier) - if err != nil { - log.Debug("Unable to marshal trace context: %s", err) + // Basic JSON structure validation + if len(detail) < 2 || detail[len(detail)-1] != '}' { + log.Debug("Unable to parse detail JSON. Not injecting trace context into EventBridge payload.") return } - // Check sizes - detailSize := 0 - if entryPtr.Detail != nil { - detailSize = len(*entryPtr.Detail) - } - traceSize := len(jsonBytes) - if detailSize+traceSize > maxSizeBytes { - log.Info("Payload size too large to pass context") - return + // Create new detail string + var newDetail string + if len(detail) > 2 { + // Case where detail is not empty + newDetail = fmt.Sprintf(`%s,"%s":%s}`, detail[:len(detail)-1], datadogKey, traceContext) + } else { + // Cae where detail is empty + newDetail = fmt.Sprintf(`{"%s":%s}`, datadogKey, traceContext) } - detail[datadogKey] = json.RawMessage(jsonBytes) - updatedDetail, err := json.Marshal(detail) - if err != nil { - log.Debug("Unable to marshal modified event detail: %s", err) + // Check sizes + if len(newDetail) > maxSizeBytes { + log.Debug("Payload size too large to pass context") return } - entryPtr.Detail = aws.String(string(updatedDetail)) + + entryPtr.Detail = aws.String(newDetail) } diff --git a/contrib/aws/internal/eventbridge/eventbridge_test.go b/contrib/aws/internal/eventbridge/eventbridge_test.go index fda81714b3..77c9ab1e72 100644 --- a/contrib/aws/internal/eventbridge/eventbridge_test.go +++ b/contrib/aws/internal/eventbridge/eventbridge_test.go @@ -8,10 +8,7 @@ package eventbridge import ( "context" "encoding/json" - "strconv" - "strings" - "testing" - + "fmt" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/eventbridge" "github.com/aws/aws-sdk-go-v2/service/eventbridge/types" @@ -20,6 +17,8 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "strings" + "testing" ) func TestEnrichOperation(t *testing.T) { @@ -71,7 +70,8 @@ func TestInjectTraceContext(t *testing.T) { defer mt.Stop() ctx := context.Background() - span, ctx := tracer.StartSpanFromContext(ctx, "test-span") + span, _ := tracer.StartSpanFromContext(ctx, "test-span") + baseTraceContext := fmt.Sprintf(`{"x-datadog-trace-id":"%d","x-datadog-parent-id":"%d","x-datadog-start-time":"123456789"`, span.Context().TraceID(), span.Context().SpanID()) tests := []struct { name string @@ -110,7 +110,7 @@ func TestInjectTraceContext(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - injectTraceContext(span, &tt.entry) + injectTraceContext(baseTraceContext, &tt.entry) tt.expected(t, &tt.entry) var detail map[string]interface{} @@ -123,11 +123,9 @@ func TestInjectTraceContext(t *testing.T) { assert.Equal(t, *tt.entry.EventBusName, ddData[resourceNameKey]) // Check that start time exists and is not empty - startTimeStr, ok := ddData[startTimeKey].(string) + startTime, ok := ddData[startTimeKey] assert.True(t, ok) - startTime, err := strconv.ParseInt(startTimeStr, 10, 64) - assert.NoError(t, err) - assert.Greater(t, startTime, int64(0)) + assert.Equal(t, startTime, "123456789") carrier := tracer.TextMapCarrier{} for k, v := range ddData { @@ -148,7 +146,7 @@ func TestInjectTraceContextSizeLimit(t *testing.T) { mt := mocktracer.Start() defer mt.Stop() - span := tracer.StartSpan("test-span") + baseTraceContext := `{"x-datadog-trace-id":"12345","x-datadog-parent-id":"67890","x-datadog-start-time":"123456789"` tests := []struct { name string @@ -187,7 +185,7 @@ func TestInjectTraceContextSizeLimit(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - injectTraceContext(span, &tt.entry) + injectTraceContext(baseTraceContext, &tt.entry) tt.expected(t, &tt.entry) }) }