Skip to content

Commit

Permalink
Use mtest for mocking otelmongo
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez committed Jul 30, 2024
1 parent e153791 commit 74892a2
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 314 deletions.
4 changes: 0 additions & 4 deletions .github/codecov.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,3 @@ comment:
layout: "reach,diff,flags,tree"
behavior: default
require_changes: yes

ignore:
# opmsg_deployment is copied from mongo-go-driver.
- "instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo/test/opmsg_deployment.go"
8 changes: 1 addition & 7 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,9 @@ vanity-import-check: | $(PORTO)
.PHONY: lint
lint: go-mod-tidy golangci-lint misspell govulncheck

# The following file is a third-party copy from the mongo-go-driver:
# ./instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo/test/opmsg_deployment.go
.PHONY: license-check
license-check:
@licRes=$$(for f in $$(find . -type f \( -iname '*.go' -o -iname '*.sh' \) \
! -path './vendor/*' \
! -path './exporters/otlp/internal/opentelemetry-proto/*' \
! -path './instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo/test/opmsg_deployment.go') ; do \
awk '/Copyright The OpenTelemetry Authors|generated|GENERATED/ && NR<=4 { found=1; next } END { if (!found) print FILENAME }' $$f; \
@licRes=$$(for f in $$(find . -type f \( -iname '*.go' -o -iname '*.sh' \) ! -path './vendor/*' ! -path './exporters/otlp/internal/opentelemetry-proto/*') ; do \
done); \
if [ -n "$${licRes}" ]; then \
echo "license header checking failed:"; echo "$${licRes}"; \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/m

go 1.21

replace go.mongodb.org/mongo-driver => /Users/preston.vasquez/Developer/mongo-go-driver-2

require (
github.com/stretchr/testify v1.9.0
go.mongodb.org/mongo-driver v1.15.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"go.mongodb.org/mongo-driver/mongo/options"

"go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo" // nolint:staticcheck // deprecated.
Expand All @@ -24,6 +25,8 @@ import (
type validator func(sdktrace.ReadOnlySpan) bool

func TestDBCrudOperation(t *testing.T) {
t.Parallel()

commonValidators := []validator{
func(s sdktrace.ReadOnlySpan) bool {
return assert.Equal(t, "test-collection.insert", s.Name(), "expected %s", s.Name())
Expand Down Expand Up @@ -80,13 +83,19 @@ func TestDBCrudOperation(t *testing.T) {
},
}
for _, tc := range tt {
tc := tc

title := tc.title
if tc.excludeCommand {
title = title + "/excludeCommand"
} else {
title = title + "/includeCommand"
}
t.Run(title, func(t *testing.T) {

mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run(title, func(mt *mtest.T) {
mt.Parallel()

sr := tracetest.NewSpanRecorder()
provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr))

Expand All @@ -103,56 +112,43 @@ func TestDBCrudOperation(t *testing.T) {
)
opts.ApplyURI(addr)

mock := newMockDeployment()

// nolint:staticcheck
//
// Deployment is not part of the stable API guarantee of the
// mongo-go-driver and is therefore marked as deprecated.
//
// See https://jira.mongodb.org/browse/GODRIVER-3241 for a long-term solution.
opts.Deployment = mock

client, err := mongo.Connect(ctx, opts)
if err != nil {
t.Fatal(err)
}

mock.addResponses(tc.mockResponses...)
t.Cleanup(mock.clearResponses)
mt.ResetClient(opts)
mt.AddMockResponses(tc.mockResponses...)

_, err = tc.operation(ctx, client.Database("test-database"))
_, err := tc.operation(ctx, mt.Client.Database("test-database"))
if err != nil {
t.Error(err)
mt.Error(err)
}

span.End()

spans := sr.Ended()
if !assert.Len(t, spans, 2, "expected 2 spans, received %d", len(spans)) {
t.FailNow()
if !assert.Len(mt, spans, 2, "expected 2 spans, received %d", len(spans)) {
mt.FailNow()
}
assert.Len(t, spans, 2)
assert.Equal(t, spans[0].SpanContext().TraceID(), spans[1].SpanContext().TraceID())
assert.Equal(t, spans[0].Parent().SpanID(), spans[1].SpanContext().SpanID())
assert.Equal(t, span.SpanContext().SpanID(), spans[1].SpanContext().SpanID())
assert.Len(mt, spans, 2)
assert.Equal(mt, spans[0].SpanContext().TraceID(), spans[1].SpanContext().TraceID())
assert.Equal(mt, spans[0].Parent().SpanID(), spans[1].SpanContext().SpanID())
assert.Equal(mt, span.SpanContext().SpanID(), spans[1].SpanContext().SpanID())

s := spans[0]
assert.Equal(t, trace.SpanKindClient, s.SpanKind())
assert.Equal(mt, trace.SpanKindClient, s.SpanKind())
attrs := s.Attributes()
assert.Contains(t, attrs, attribute.String("db.system", "mongodb"))
assert.Contains(t, attrs, attribute.String("net.peer.name", "<mock_connection>"))
assert.Contains(t, attrs, attribute.Int64("net.peer.port", int64(27017)))
assert.Contains(t, attrs, attribute.String("net.transport", "ip_tcp"))
assert.Contains(t, attrs, attribute.String("db.name", "test-database"))
assert.Contains(mt, attrs, attribute.String("db.system", "mongodb"))
assert.Contains(mt, attrs, attribute.String("net.peer.name", "<mock_connection>"))
assert.Contains(mt, attrs, attribute.Int64("net.peer.port", int64(27017)))
assert.Contains(mt, attrs, attribute.String("net.transport", "ip_tcp"))
assert.Contains(mt, attrs, attribute.String("db.name", "test-database"))
for _, v := range tc.validators {
assert.True(t, v(s))
assert.True(mt, v(s))
}
})
}
}

func TestDBCollectionAttribute(t *testing.T) {
t.Parallel()

tt := []struct {
title string
operation func(context.Context, *mongo.Database) (interface{}, error)
Expand Down Expand Up @@ -205,7 +201,12 @@ func TestDBCollectionAttribute(t *testing.T) {
},
}
for _, tc := range tt {
t.Run(tc.title, func(t *testing.T) {
tc := tc

mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run(tc.title, func(mt *mtest.T) {
mt.Parallel()

sr := tracetest.NewSpanRecorder()
provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr))

Expand All @@ -222,50 +223,35 @@ func TestDBCollectionAttribute(t *testing.T) {
)
opts.ApplyURI(addr)

mock := newMockDeployment()

// nolint:staticcheck
//
// Deployment is not part of the stable API guarantee of the
// mongo-go-driver and is therefore marked as deprecated.
//
// See https://jira.mongodb.org/browse/GODRIVER-3241 for a long-term solution.
opts.Deployment = mock

client, err := mongo.Connect(ctx, opts)
if err != nil {
t.Fatal(err)
}

mock.addResponses(tc.mockResponses...)
t.Cleanup(mock.clearResponses)
mt.ResetClient(opts)
mt.AddMockResponses(tc.mockResponses...)

_, err = tc.operation(ctx, client.Database("test-database"))
_, err := tc.operation(ctx, mt.Client.Database("test-database"))
if err != nil {
t.Error(err)
mt.Error(err)
}

span.End()

spans := sr.Ended()
if !assert.Len(t, spans, 2, "expected 2 spans, received %d", len(spans)) {
t.FailNow()
if !assert.Len(mt, spans, 2, "expected 2 spans, received %d", len(spans)) {
mt.FailNow()
}
assert.Len(t, spans, 2)
assert.Equal(t, spans[0].SpanContext().TraceID(), spans[1].SpanContext().TraceID())
assert.Equal(t, spans[0].Parent().SpanID(), spans[1].SpanContext().SpanID())
assert.Equal(t, span.SpanContext().SpanID(), spans[1].SpanContext().SpanID())
assert.Len(mt, spans, 2)
assert.Equal(mt, spans[0].SpanContext().TraceID(), spans[1].SpanContext().TraceID())
assert.Equal(mt, spans[0].Parent().SpanID(), spans[1].SpanContext().SpanID())
assert.Equal(mt, span.SpanContext().SpanID(), spans[1].SpanContext().SpanID())

s := spans[0]
assert.Equal(t, trace.SpanKindClient, s.SpanKind())
assert.Equal(mt, trace.SpanKindClient, s.SpanKind())
attrs := s.Attributes()
assert.Contains(t, attrs, attribute.String("db.system", "mongodb"))
assert.Contains(t, attrs, attribute.String("net.peer.name", "<mock_connection>"))
assert.Contains(t, attrs, attribute.Int64("net.peer.port", int64(27017)))
assert.Contains(t, attrs, attribute.String("net.transport", "ip_tcp"))
assert.Contains(t, attrs, attribute.String("db.name", "test-database"))
assert.Contains(mt, attrs, attribute.String("db.system", "mongodb"))
assert.Contains(mt, attrs, attribute.String("net.peer.name", "<mock_connection>"))
assert.Contains(mt, attrs, attribute.Int64("net.peer.port", int64(27017)))
assert.Contains(mt, attrs, attribute.String("net.transport", "ip_tcp"))
assert.Contains(mt, attrs, attribute.String("db.name", "test-database"))
for _, v := range tc.validators {
assert.True(t, v(s))
assert.True(mt, v(s))
}
})
}
Expand Down
Loading

0 comments on commit 74892a2

Please sign in to comment.