Skip to content

Commit

Permalink
feat: Add function to set correlation ID
Browse files Browse the repository at this point in the history
Signed-off-by: Bob Stasyszyn <[email protected]>
  • Loading branch information
bstasyszyn committed Nov 1, 2024
1 parent 4f33656 commit 59a4ff9
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 101 deletions.
86 changes: 86 additions & 0 deletions pkg/otel/correlationid/correlationid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
Copyright Gen Digital Inc. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/

package correlationid

import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strings"

"go.opentelemetry.io/otel/trace"

"github.com/trustbloc/logutil-go/pkg/otel/api"
)

const (
nilTraceID = "00000000000000000000000000000000"
correlationIDLength = 8
)

type contextKey struct{}

// Set derives the correlation ID from the OpenTelemetry trace ID and sets it on the returned context.
// If no trace ID is available, a random correlation ID is generated.
func Set(ctx context.Context) (context.Context, string, error) {
var correlationID string

traceID := trace.SpanFromContext(ctx).SpanContext().TraceID().String()
if traceID != "" && traceID != nilTraceID {
correlationID = deriveID(traceID)
} else {
var err error
correlationID, err = generateID()
if err != nil {
return nil, "", fmt.Errorf("generate correlation ID: %w", err)
}
}

return context.WithValue(ctx, contextKey{}, correlationID), correlationID, nil
}

// Transport is an HTTP RoundTripper that adds a correlation ID to the request header.
type Transport struct {
defaultTransport http.RoundTripper
}

// NewHTTPTransport creates a new HTTP Transport.
func NewHTTPTransport(defaultTransport http.RoundTripper) *Transport {
return &Transport{
defaultTransport: defaultTransport,
}
}

// RoundTrip executes a single HTTP transaction.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
correlationID, ok := req.Context().Value(contextKey{}).(string)
if ok {
req = req.Clone(req.Context())
req.Header.Add(api.CorrelationIDHeader, correlationID)
}

return t.defaultTransport.RoundTrip(req)
}

func generateID() (string, error) {
bytes := make([]byte, correlationIDLength/2) //nolint:gomnd

if _, err := rand.Read(bytes); err != nil {
return "", err
}

return strings.ToUpper(hex.EncodeToString(bytes)), nil
}

func deriveID(id string) string {
hash := sha256.Sum256([]byte(id))

return strings.ToUpper(hex.EncodeToString(hash[:correlationIDLength/2])) //nolint:gomnd
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Copyright Gen Digital Inc. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/

package correlationidtransport
package correlationid

import (
"context"
Expand All @@ -19,17 +19,20 @@ import (
)

func TestTransport_RoundTrip(t *testing.T) {
var rt mockRoundTripperFunc = func(req *http.Request) (*http.Response, error) {
correlationID := req.Header.Get(api.CorrelationIDHeader)
t.Run("No span", func(t *testing.T) {
var rt mockRoundTripperFunc = func(req *http.Request) (*http.Response, error) {
require.Len(t, req.Header.Get(api.CorrelationIDHeader), 8)

require.Len(t, correlationID, 8)
return &http.Response{}, nil
}
return &http.Response{}, nil
}

transport := New(rt, WithCorrelationIDLength(8))
transport := NewHTTPTransport(rt)

t.Run("No span", func(t *testing.T) {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com", nil)
ctx, correlationID, err := Set(context.Background())
require.NoError(t, err)
require.NotEmpty(t, correlationID)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
require.NoError(t, err)

resp, err := transport.RoundTrip(req)
Expand All @@ -38,13 +41,27 @@ func TestTransport_RoundTrip(t *testing.T) {
})

t.Run("With span", func(t *testing.T) {
var correlationID string

var rt mockRoundTripperFunc = func(req *http.Request) (*http.Response, error) {
require.Equal(t, correlationID, req.Header.Get(api.CorrelationIDHeader))

return &http.Response{}, nil
}

transport := NewHTTPTransport(rt)

tp := trace.NewTracerProvider()

otel.SetTracerProvider(tp)

ctx, span := tp.Tracer("test").Start(context.Background(), "test")
require.NotNil(t, span)

var err error
ctx, correlationID, err = Set(ctx)
require.NoError(t, err)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
require.NoError(t, err)

Expand Down
92 changes: 0 additions & 92 deletions pkg/otel/correlationidtransport/correlationidtransport.go

This file was deleted.

0 comments on commit 59a4ff9

Please sign in to comment.