diff --git a/go.mod b/go.mod index c9a8dd4069772..280917c17ed59 100644 --- a/go.mod +++ b/go.mod @@ -153,7 +153,7 @@ require ( github.com/vulcand/predicate v1.2.0 // replaced go.etcd.io/etcd/api/v3 v3.5.9 go.etcd.io/etcd/client/v3 v3.5.9 - go.mongodb.org/mongo-driver v1.13.0-prerelease.0.20230726045955-5ee10b94cc66 + go.mongodb.org/mongo-driver v1.13.0 go.opentelemetry.io/contrib/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws v0.46.1 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 diff --git a/go.sum b/go.sum index b4536275e04f9..75aa010ba32ee 100644 --- a/go.sum +++ b/go.sum @@ -1584,8 +1584,8 @@ go.etcd.io/etcd/tests/v3 v3.5.0/go.mod h1:f+mtZ1bE1YPvgKdOJV2BKy4JQW0nAFnQehgOE7 go.etcd.io/etcd/v3 v3.5.0-alpha.0/go.mod h1:JZ79d3LV6NUfPjUxXrpiFAYcjhT+06qqw+i28snx8To= go.etcd.io/etcd/v3 v3.5.0/go.mod h1:FldM0/VzcxYWLvWx1sdA7ghKw7C3L2DvUTzGrcEtsC4= go.mongodb.org/mongo-driver v1.11.4/go.mod h1:PTSz5yu21bkT/wXpkS7WR5f0ddqw5quethTUn9WM+2g= -go.mongodb.org/mongo-driver v1.13.0-prerelease.0.20230726045955-5ee10b94cc66 h1:2uTsucgz0YmaUEvk4iu43KGxvVcG/bZ/rNsCmqiMGC4= -go.mongodb.org/mongo-driver v1.13.0-prerelease.0.20230726045955-5ee10b94cc66/go.mod h1:AZkxhPnFJUoH7kZlFkVKucV20K387miPfm7oimrSmK0= +go.mongodb.org/mongo-driver v1.13.0 h1:67DgFFjYOCMWdtTEmKFpV3ffWlFnh+CYZ8ZS/tXWUfY= +go.mongodb.org/mongo-driver v1.13.0/go.mod h1:/rGBTebI3XYboVmgz+Wv3Bcbl3aD0QF9zl6kDDw18rQ= go.opencensus.io v0.15.0/go.mod h1:UffZAU+4sDEINUGP/B7UfBBkq4fqLu9zXAX7ke6CHW0= go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 9f757df45e17f..13839a7e28621 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -45,7 +45,6 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" - "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" sqladmin "google.golang.org/api/sqladmin/v1beta4" "github.com/gravitational/teleport" @@ -76,6 +75,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/dynamodb" "github.com/gravitational/teleport/lib/srv/db/elasticsearch" "github.com/gravitational/teleport/lib/srv/db/mongodb" + "github.com/gravitational/teleport/lib/srv/db/mongodb/protocol" "github.com/gravitational/teleport/lib/srv/db/mysql" "github.com/gravitational/teleport/lib/srv/db/opensearch" "github.com/gravitational/teleport/lib/srv/db/postgres" @@ -862,7 +862,7 @@ func TestAccessMongoDB(t *testing.T) { { name: "current server", opts: []mongodb.TestServerOption{ - mongodb.TestServerWireVersion(wiremessage.OpmsgWireVersion), + mongodb.TestServerWireVersion(protocol.OpmsgWireVersion), }, }, { @@ -903,15 +903,15 @@ func TestAccessMongoDB(t *testing.T) { testCtx := setupTestContext(ctx, t, withSelfHostedMongo("mongo", serverOpt.opts...)) go testCtx.startHandlingConnections() + // Create user/role with the requested permissions. + testCtx.createUserAndRole(ctx, t, test.user, test.role, test.allowDbUsers, test.allowDbNames) + for _, clientOpt := range clientOpts { clientOpt := clientOpt t.Run(fmt.Sprintf("%v/%v", serverOpt.name, clientOpt.name), func(t *testing.T) { t.Parallel() - // Create user/role with the requested permissions. - testCtx.createUserAndRole(ctx, t, test.user, test.role, test.allowDbUsers, test.allowDbNames) - // Try to connect to the database as this user. mongoClient, err := testCtx.mongoClient(ctx, test.user, "mongo", test.dbUser, clientOpt.opts) t.Cleanup(func() { @@ -951,13 +951,13 @@ func TestMongoDBMaxMessageSize(t *testing.T) { expectedQueryError bool }{ "default message size": { - messageSize: 256, + messageSize: 300, }, "message size exceeded": { // Set a value that will enable handshake message to complete // successfully. - maxMessageSize: 256, - messageSize: 512, + maxMessageSize: 300, + messageSize: 500, expectedQueryError: true, }, } { diff --git a/lib/srv/db/mongodb/protocol/deprecated_wiremessage.go b/lib/srv/db/mongodb/protocol/deprecated_wiremessage.go new file mode 100644 index 0000000000000..35942df0b87f1 --- /dev/null +++ b/lib/srv/db/mongodb/protocol/deprecated_wiremessage.go @@ -0,0 +1,51 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package protocol + +import ( + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" +) + +// This file contains logic which has been deprecated from MongoDB's client library, but needs to be supported for +// our backwards compatibility needs. This deprecation started in MongoDB 1.13.0. + +// OpmsgWireVersion is the minimum wire version needed to use OP_MSG +const OpmsgWireVersion = 6 + +// ReadQueryFlags reads OP_QUERY flags from src. +func ReadQueryFlags(src []byte) (flags wiremessage.QueryFlag, rem []byte, ok bool) { + i32, rem, ok := readInt32(src) + return wiremessage.QueryFlag(i32), rem, ok +} + +// ReadQueryFullCollectionName reads the full collection name from src. +func ReadQueryFullCollectionName(src []byte) (collname string, rem []byte, ok bool) { + return readCString(src) +} + +// ReadQueryNumber is a replacement for ReadQueryNumberToSkip or ReadQueryNumberToSkip. This function reads a 32 bit +// integer from src. +func ReadQueryNumber(src []byte) (nts int32, rem []byte, ok bool) { + return readInt32(src) +} + +// ReadDocument is a replacement for ReadQueryQuery or ReadQueryReturnFieldsSelector. This function reads a bson +// document from src. +func ReadDocument(src []byte) (rfs bsoncore.Document, rem []byte, ok bool) { + return bsoncore.ReadDocument(src) +} diff --git a/lib/srv/db/mongodb/protocol/message.go b/lib/srv/db/mongodb/protocol/message.go index 51342870a2407..0c839cd312bb7 100644 --- a/lib/srv/db/mongodb/protocol/message.go +++ b/lib/srv/db/mongodb/protocol/message.go @@ -47,6 +47,24 @@ type Message interface { fmt.Stringer } +// These OpCode's define what Teleport supports. They values were up to date as of MongoDB 1.13.0 +// We need to reference these locally as MongoDB is deprecating some of these, but we need to maintain backwards +// compatibility. The state of deprecation can be witnessed by referencing the libraries version when possible, or +// static definition where no longer available. +const ( + OpReply = wiremessage.OpReply + OpUpdate = wiremessage.OpUpdate + OpInsert = wiremessage.OpInsert + OpQuery wiremessage.OpCode = 2004 + OpGetMore = wiremessage.OpGetMore + OpDelete wiremessage.OpCode = wiremessage.OpDelete + OpKillCursors wiremessage.OpCode = wiremessage.OpKillCursors + OpCommand wiremessage.OpCode = wiremessage.OpCommand + OpCommandReply wiremessage.OpCode = wiremessage.OpCommandReply + OpCompressed wiremessage.OpCode = wiremessage.OpCompressed + OpMsg wiremessage.OpCode = wiremessage.OpMsg +) + // ReadMessage reads the next MongoDB wire protocol message from the reader. func ReadMessage(reader io.Reader, maxMessageSize uint32) (Message, error) { header, payload, err := readHeaderAndPayload(reader, maxMessageSize) @@ -54,23 +72,23 @@ func ReadMessage(reader io.Reader, maxMessageSize uint32) (Message, error) { return nil, trace.Wrap(err) } switch header.OpCode { - case wiremessage.OpMsg: + case OpMsg: return readOpMsg(*header, payload) - case wiremessage.OpQuery: + case OpQuery: return readOpQuery(*header, payload) - case wiremessage.OpGetMore: + case OpGetMore: return readOpGetMore(*header, payload) - case wiremessage.OpInsert: + case OpInsert: return readOpInsert(*header, payload) - case wiremessage.OpUpdate: + case OpUpdate: return readOpUpdate(*header, payload) - case wiremessage.OpDelete: + case OpDelete: return readOpDelete(*header, payload) - case wiremessage.OpCompressed: + case OpCompressed: return readOpCompressed(*header, payload, maxMessageSize) - case wiremessage.OpReply: + case OpReply: return readOpReply(*header, payload) - case wiremessage.OpKillCursors: + case OpKillCursors: return readOpKillCursors(*header, payload) } return nil, trace.BadParameter("unknown wire protocol message: %v %v", diff --git a/lib/srv/db/mongodb/protocol/message_test.go b/lib/srv/db/mongodb/protocol/message_test.go index ad2cc10a964b8..8b415208fc209 100644 --- a/lib/srv/db/mongodb/protocol/message_test.go +++ b/lib/srv/db/mongodb/protocol/message_test.go @@ -431,7 +431,8 @@ func makeTestOpQuery(t *testing.T) *MessageOpQuery { ReturnFieldsSelector: makeTestDocument(t), } msg.bytes = msg.ToWire(0) - msg.Header = makeTestHeader(msg.bytes, wiremessage.OpQuery) + // OpQuery is deprecated, we define the code directly to make sure that our mapping is correct + msg.Header = makeTestHeader(msg.bytes, wiremessage.OpCode(2004)) return msg } diff --git a/lib/srv/db/mongodb/protocol/opquery.go b/lib/srv/db/mongodb/protocol/opquery.go index 40ae9619125c1..cf9c5b91690d7 100644 --- a/lib/srv/db/mongodb/protocol/opquery.go +++ b/lib/srv/db/mongodb/protocol/opquery.go @@ -86,29 +86,29 @@ func (m *MessageOpQuery) MoreToCome(_ Message) bool { // // https://docs.mongodb.com/manual/reference/mongodb-wire-protocol/#op_query func readOpQuery(header MessageHeader, payload []byte) (*MessageOpQuery, error) { - flags, rem, ok := wiremessage.ReadQueryFlags(payload) + flags, rem, ok := ReadQueryFlags(payload) if !ok { return nil, trace.BadParameter("malformed OP_QUERY: missing flags %v", payload) } - fullCollectionName, rem, ok := wiremessage.ReadQueryFullCollectionName(rem) + fullCollectionName, rem, ok := ReadQueryFullCollectionName(rem) if !ok { return nil, trace.BadParameter("malformed OP_QUERY: missing full collection name %v", payload) } - numberToSkip, rem, ok := wiremessage.ReadQueryNumberToSkip(rem) + numberToSkip, rem, ok := ReadQueryNumber(rem) if !ok { return nil, trace.BadParameter("malformed OP_QUERY: missing number to skip %v", payload) } - numberToReturn, rem, ok := wiremessage.ReadQueryNumberToReturn(rem) + numberToReturn, rem, ok := ReadQueryNumber(rem) if !ok { return nil, trace.BadParameter("malformed OP_QUERY: missing number to return %v", payload) } - query, rem, ok := wiremessage.ReadQueryQuery(rem) + query, rem, ok := ReadDocument(rem) if !ok { return nil, trace.BadParameter("malformed OP_QUERY: missing query %v", payload) } var returnFieldsSelector bsoncore.Document if len(rem) > 0 { - returnFieldsSelector, _, ok = wiremessage.ReadQueryReturnFieldsSelector(rem) + returnFieldsSelector, _, ok = ReadDocument(rem) if !ok { return nil, trace.BadParameter("malformed OP_QUERY: missing return field selector %v", payload) } @@ -130,6 +130,7 @@ func readOpQuery(header MessageHeader, payload []byte) (*MessageOpQuery, error) // https://docs.mongodb.com/manual/reference/mongodb-wire-protocol/#op_query func (m *MessageOpQuery) ToWire(responseTo int32) (dst []byte) { var idx int32 + //nolint:staticcheck // ignore deprecation till OpQuery is removed, at which point this wire format should be updated idx, dst = wiremessage.AppendHeaderStart(dst, m.Header.RequestID, responseTo, wiremessage.OpQuery) dst = wiremessage.AppendQueryFlags(dst, m.Flags) dst = wiremessage.AppendQueryFullCollectionName(dst, m.FullCollectionName) diff --git a/lib/srv/db/mongodb/test.go b/lib/srv/db/mongodb/test.go index 0f90e001866fd..717357d9bb783 100644 --- a/lib/srv/db/mongodb/test.go +++ b/lib/srv/db/mongodb/test.go @@ -280,12 +280,12 @@ func (s *TestServer) handleFind(message protocol.Message) (protocol.Message, err // handleSaslStart makes response to the client's "saslStart" command. func (s *TestServer) handleSaslStart(message protocol.Message) (protocol.Message, error) { - opmsg, ok := message.(*protocol.MessageOpQuery) + opmsg, ok := message.(*protocol.MessageOpMsg) if !ok { - return nil, trace.BadParameter("expected message type *protocol.MessageOpQuery but got %T", message) + return nil, trace.BadParameter("expected message type *protocol.MessageOpMsg but got %T", message) } - mechanism := opmsg.Query.Lookup("mechanism").StringValue() + mechanism := opmsg.BodySection.Document.Lookup("mechanism").StringValue() conversationID := atomic.AddInt32(&s.conversationIdx, 1) s.saslConversationTracker.Store(conversationID, mechanism) @@ -301,12 +301,12 @@ func (s *TestServer) handleSaslStart(message protocol.Message) (protocol.Message // It expects a conversion to be present at `saslConversationTracker`, // otherwise it won't be able to define which authentication mechanism to use. func (s *TestServer) handleSaslContinue(message protocol.Message) (protocol.Message, error) { - opmsg, ok := message.(*protocol.MessageOpQuery) + opmsg, ok := message.(*protocol.MessageOpMsg) if !ok { - return nil, trace.BadParameter("expected message type *protocol.MessageOpQuery but got %T", message) + return nil, trace.BadParameter("expected message type *protocol.MessageOpMsg but got %T", message) } - conversationID := opmsg.Query.Lookup("conversationId").Int32() + conversationID := opmsg.BodySection.Document.Lookup("conversationId").Int32() mechanism, ok := s.saslConversationTracker.Load(conversationID) if !ok { return nil, trace.NotFound("conversationID not found") @@ -322,8 +322,8 @@ func (s *TestServer) handleSaslContinue(message protocol.Message) (protocol.Mess // handleAWSIAMSaslStart handles the "saslStart" command for "MONGODB-AWS" // authentication. -func (s *TestServer) handleAWSIAMSaslStart(conversationID int32, opmsg *protocol.MessageOpQuery) (protocol.Message, error) { - _, userPass := opmsg.Query.Lookup("payload").Binary() +func (s *TestServer) handleAWSIAMSaslStart(conversationID int32, opmsg *protocol.MessageOpMsg) (protocol.Message, error) { + _, userPass := opmsg.BodySection.Document.Lookup("payload").Binary() doc, _, ok := bsoncore.ReadDocument(userPass) if !ok { return nil, trace.BadParameter("invalid payload") @@ -358,8 +358,8 @@ func (s *TestServer) handleAWSIAMSaslStart(conversationID int32, opmsg *protocol // handleAWSIAMSaslContinue handles the "saslStart" command for "MONGODB-AWS" // authentication. -func (s *TestServer) handleAWSIAMSaslContinue(conversationID int32, opmsg *protocol.MessageOpQuery) (protocol.Message, error) { - _, awsSaslPayload := opmsg.Query.Lookup("payload").Binary() +func (s *TestServer) handleAWSIAMSaslContinue(conversationID int32, opmsg *protocol.MessageOpMsg) (protocol.Message, error) { + _, awsSaslPayload := opmsg.BodySection.Document.Lookup("payload").Binary() doc, _, ok := bsoncore.ReadDocument(awsSaslPayload) if !ok { return nil, trace.BadParameter("invalid payload")