Skip to content

Commit

Permalink
feat: support tls of mongodb (bytebase#10868)
Browse files Browse the repository at this point in the history
* chore: enable tls for docdb

Signed-off-by: h3n4l <[email protected]>

* chore: update

Signed-off-by: h3n4l <[email protected]>

* chore: compability

Signed-off-by: h3n4l <[email protected]>

* chore: compability

Signed-off-by: h3n4l <[email protected]>

* chore: updatE

Signed-off-by: h3n4l <[email protected]>

---------

Signed-off-by: h3n4l <[email protected]>
  • Loading branch information
h3n4l authored Feb 27, 2024
1 parent 3adb56c commit 669cd49
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 36 deletions.
85 changes: 63 additions & 22 deletions backend/plugin/db/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,15 @@ func newDriver(dc db.DriverConfig) db.Driver {

// Open opens a MongoDB driver.
func (driver *Driver) Open(ctx context.Context, _ storepb.Engine, connCfg db.ConnectionConfig) (db.Driver, error) {
connectionURI := getMongoDBConnectionURI(connCfg)
connectionURI := getBasicMongoDBConnectionURI(connCfg)
opts := options.Client().ApplyURI(connectionURI)
tlsConfig, err := connCfg.TLSConfig.GetSslConfig()
if err != nil {
return nil, errors.Wrap(err, "failed to get SSL config")
}
if tlsConfig != nil {
opts.SetTLSConfig(tlsConfig)
}
client, err := mongo.Connect(ctx, opts)
if err != nil {
return nil, errors.Wrap(err, "failed to create MongoDB client")
Expand Down Expand Up @@ -91,14 +98,36 @@ func (*Driver) GetDB() *sql.DB {

// Execute executes a statement, always returns 0 as the number of rows affected because we execute the statement by mongosh, it's hard to catch the row effected number.
func (driver *Driver) Execute(ctx context.Context, statement string, _ db.ExecuteOptions) (int64, error) {
connectionURI := getMongoDBConnectionURI(driver.connCfg)
connectionURI := getBasicMongoDBConnectionURI(driver.connCfg)
// For MongoDB, we execute the statement in mongosh, which is a shell for MongoDB.
// There are some ways to execute the statement in mongosh:
// 1. Use the --eval option to execute the statement.
// 2. Use the --file option to execute the statement from a file.
// We choose the second way with the following reasons:
// 1. The statement may too long to be executed in the command line.
// 2. We cannot catch the error from the --eval option.
mongoshArgs := []string{
connectionURI,
// DocumentDB do not support retryWrites, so we set it to false.
"--retryWrites",
"false",
"--quiet",
}

if driver.connCfg.TLSConfig.SslCA != "" {
mongoshArgs = append(mongoshArgs, "--tls")
// Write the tlsCAFile to a temporary file, and use the temporary file as the value of --tlsCAFile.
// The reason is that the --tlsCAFile option of mongosh does not support the value of the certificate directly.
caFileName := fmt.Sprintf("mongodb-tls-ca-%s-%s", driver.connCfg.ConnectionDatabase, uuid.New().String())
defer func() {
// While error occurred in mongosh, the temporary file may not created, so we ignore the error here.
_ = os.Remove(caFileName)
}()
if err := os.WriteFile(caFileName, []byte(driver.connCfg.TLSConfig.SslCA), 0400); err != nil {
return 0, errors.Wrap(err, "failed to write tlsCAFile to temporary file")
}
mongoshArgs = append(mongoshArgs, "--tlsCAFile", caFileName)
}

// First, we create a temporary file to store the statement.
tempDir := os.TempDir()
Expand All @@ -113,14 +142,8 @@ func (driver *Driver) Execute(ctx context.Context, statement string, _ db.Execut
if err := tempFile.Close(); err != nil {
return 0, errors.Wrap(err, "failed to close temporary file")
}
mongoshArgs = append(mongoshArgs, "--file", tempFile.Name())

// Then, we execute the statement in mongosh.
mongoshArgs := []string{
connectionURI,
"--quiet",
"--file",
tempFile.Name(),
}
mongoshCmd := exec.CommandContext(ctx, mongoutil.GetMongoshPath(driver.dbBinDir), mongoshArgs...)
var errContent bytes.Buffer
mongoshCmd.Stderr = &errContent
Expand All @@ -140,9 +163,9 @@ func (*Driver) Restore(_ context.Context, _ io.Reader) error {
panic("not implemented")
}

// getMongoDBConnectionURI returns the MongoDB connection URI.
// getBasicMongoDBConnectionURI returns the MongoDB connection URI.
// https://www.mongodb.com/docs/manual/reference/connection-string/
func getMongoDBConnectionURI(connConfig db.ConnectionConfig) string {
func getBasicMongoDBConnectionURI(connConfig db.ConnectionConfig) string {
u := &url.URL{
Scheme: "mongodb",
// In RFC, there can be no tailing slash('/') in the path if the path is empty and the query is not empty.
Expand Down Expand Up @@ -176,7 +199,7 @@ func (driver *Driver) QueryConn(ctx context.Context, _ *sql.Conn, statement stri
statement = strings.Trim(statement, " \t\n\r\f;")
simpleStatement := isMongoStatement(statement)
startTime := time.Now()
connectionURI := getMongoDBConnectionURI(driver.connCfg)
connectionURI := getBasicMongoDBConnectionURI(driver.connCfg)
// For MongoDB query, we execute the statement in mongosh with flag --eval for the following reasons:
// 1. Query always short, so it's safe to execute in the command line.
// 2. We cannot catch the output if we use the --file option.
Expand All @@ -198,21 +221,39 @@ func (driver *Driver) QueryConn(ctx context.Context, _ *sql.Conn, statement stri
evalArg = strings.ReplaceAll(evalArg, `'`, `'"'`)
evalArg = fmt.Sprintf(`'%s'`, evalArg)

fileName := fmt.Sprintf("mongodb-query-%s-%s", driver.connCfg.ConnectionDatabase, uuid.New().String())
defer func() {
// While error occurred in mongosh, the temporary file may not created, so we ignore the error here.
_ = os.Remove(fileName)
}()
mongoshArgs := []string{
mongoutil.GetMongoshPath(driver.dbBinDir),
connectionURI,
"--quiet",
"--eval",
evalArg,
">",
fileName,
// DocumentDB do not support retryWrites, so we set it to false.
"--retryWrites",
"false",
}

if driver.connCfg.TLSConfig.SslCA != "" {
mongoshArgs = append(mongoshArgs, "--tls")
// Write the tlsCAFile to a temporary file, and use the temporary file as the value of --tlsCAFile.
// The reason is that the --tlsCAFile option of mongosh does not support the value of the certificate directly.
caFileName := fmt.Sprintf("mongodb-tls-ca-%s-%s", driver.connCfg.ConnectionDatabase, uuid.New().String())
defer func() {
// While error occurred in mongosh, the temporary file may not created, so we ignore the error here.
_ = os.Remove(caFileName)
}()
if err := os.WriteFile(caFileName, []byte(driver.connCfg.TLSConfig.SslCA), 0400); err != nil {
return nil, errors.Wrap(err, "failed to write tlsCAFile to temporary file")
}
mongoshArgs = append(mongoshArgs, "--tlsCAFile", caFileName)
}

queryResultFileName := fmt.Sprintf("mongodb-query-%s-%s", driver.connCfg.ConnectionDatabase, uuid.New().String())
defer func() {
// While error occurred in mongosh, the temporary file may not created, so we ignore the error here.
_ = os.Remove(queryResultFileName)
}()
mongoshArgs = append(mongoshArgs, ">", queryResultFileName)

shellArgs := []string{
"-c",
strings.Join(mongoshArgs, " "),
Expand All @@ -226,15 +267,15 @@ func (driver *Driver) QueryConn(ctx context.Context, _ *sql.Conn, statement stri
return nil, errors.Wrapf(err, "failed to execute statement in mongosh: \n stdout: %s\n stderr: %s", outContent.String(), errContent.String())
}

f, err := os.OpenFile(fileName, os.O_RDONLY, 0644)
f, err := os.OpenFile(queryResultFileName, os.O_RDONLY, 0644)
if err != nil {
return nil, errors.Wrapf(err, "failed to open file: %s", fileName)
return nil, errors.Wrapf(err, "failed to open file: %s", queryResultFileName)
}
defer f.Close()

content, err := io.ReadAll(f)
if err != nil {
return nil, errors.Wrapf(err, "failed to read file: %s", fileName)
return nil, errors.Wrapf(err, "failed to read file: %s", queryResultFileName)
}

if simpleStatement {
Expand Down
2 changes: 1 addition & 1 deletion backend/plugin/db/mongodb/mongodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestGetMongoDBConnectionURL(t *testing.T) {

a := require.New(t)
for _, tt := range tests {
got := getMongoDBConnectionURI(tt.connConfig)
got := getBasicMongoDBConnectionURI(tt.connConfig)
a.Equal(tt.want, got)
}
}
Expand Down
50 changes: 42 additions & 8 deletions backend/plugin/db/mongodb/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,51 @@ func (driver *Driver) SyncDBSchema(ctx context.Context) (*storepb.DatabaseSchema
}

database := driver.client.Database(driver.databaseName)
collectionList, err := database.ListCollectionNames(ctx, bson.M{"type": "collection"})
collectionList, err := database.ListCollections(ctx, bson.M{})
if err != nil {
return nil, errors.Wrap(err, "failed to list collection names")
}
sort.Strings(collectionList)
var collectionNames []string
var viewNames []string
for collectionList.Next(ctx) {
var collection bson.M
if err := collectionList.Decode(&collection); err != nil {
return nil, errors.Wrap(err, "failed to decode collection")
}
var tp string
if t, ok := collection["type"]; ok {
if s, ok := t.(string); ok && s == "collection" {
tp = "collection"
}
if s, ok := t.(string); ok && s == "view" {
tp = "view"
}
}
name, ok := collection["name"]
if !ok {
return nil, errors.New("cannot get collection name from collection info")
}
collectionName, ok := name.(string)
if !ok {
return nil, errors.New("cannot convert collection name to string")
}
switch tp {
case "collection":
collectionNames = append(collectionNames, collectionName)
case "view":
viewNames = append(viewNames, collectionName)
}
}
if err := collectionList.Err(); err != nil {
return nil, errors.Wrap(err, "failed to list collection names")
}
if err := collectionList.Close(ctx); err != nil {
return nil, errors.Wrap(err, "failed to close collection list")
}
sort.Strings(collectionNames)
sort.Strings(viewNames)

for _, collectionName := range collectionList {
for _, collectionName := range collectionNames {
if systemCollection[collectionName] {
continue
}
Expand Down Expand Up @@ -150,11 +188,7 @@ func (driver *Driver) SyncDBSchema(ctx context.Context) (*storepb.DatabaseSchema
})
}

viewList, err := database.ListCollectionNames(ctx, bson.M{"type": "view"})
if err != nil {
return nil, errors.Wrap(err, "failed to list view names")
}
for _, viewName := range viewList {
for _, viewName := range viewNames {
schemaMetadata.Views = append(schemaMetadata.Views, &storepb.ViewMetadata{Name: viewName})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,19 @@
</label>
</div>
<template v-if="dataSource.pendingCreate">
<SslCertificateForm :value="dataSource" @change="handleSSLChange" />
<SslCertificateForm
:value="dataSource"
:engine-type="basicInfo.engine"
@change="handleSSLChange"
/>
</template>
<template v-else>
<template v-if="dataSource.updateSsl">
<SslCertificateForm :value="dataSource" @change="handleSSLChange" />
<SslCertificateForm
:value="dataSource"
:engine-type="basicInfo.engine"
@change="handleSSLChange"
/>
</template>
<template v-else>
<NButton
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
v-model:value="state.type"
class="!flex flex-row items-center gap-x-4 mt-2"
>
<NRadio v-for="type in SslTypes" :key="type" :value="type">
<NRadio
v-for="type in sslTypeCandidatesOfEngine(props.engineType)"
:key="type"
:value="type"
>
<span class="textlabel">{{ getSslTypeLabel(type) }}</span>
</NRadio>
</NRadioGroup>
Expand Down Expand Up @@ -60,10 +64,9 @@ import { NTabs, NTabPane, NRadio, NRadioGroup } from "naive-ui";
import { PropType, reactive, watch } from "vue";
import { useI18n } from "vue-i18n";
import DroppableTextarea from "@/components/misc/DroppableTextarea.vue";
import { Engine } from "@/types/proto/v1/common";
import { DataSource } from "@/types/proto/v1/instance_service";
const SslTypes = ["NONE", "CA", "CA+KEY+CERT"] as const;
type SslType = "NONE" | "CA" | "CA+KEY+CERT";
type WithSslOptions = Partial<Pick<DataSource, "sslCa" | "sslCert" | "sslKey">>;
Expand All @@ -79,6 +82,10 @@ const props = defineProps({
type: Object as PropType<WithSslOptions>,
required: true,
},
engineType: {
type: Object as PropType<Engine>,
required: true,
},
});
const emit = defineEmits<{
Expand Down Expand Up @@ -151,4 +158,11 @@ function guessSslType(value: WithSslOptions): SslType {
}
return "NONE";
}
function sslTypeCandidatesOfEngine(engineType: Engine): SslType[] {
if (engineType === Engine.MONGODB) {
return ["NONE", "CA"];
}
return ["NONE", "CA", "CA+KEY+CERT"];
}
</script>
1 change: 1 addition & 0 deletions frontend/src/utils/v1/instance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ export const instanceV1HasSSL = (
Engine.DM,
Engine.STARROCKS,
Engine.DORIS,
Engine.MONGODB,
].includes(engine);
};

Expand Down

0 comments on commit 669cd49

Please sign in to comment.