Skip to content

Commit

Permalink
write encoded source_paths to DB to correctly handle database names w…
Browse files Browse the repository at this point in the history
…ith commas
  • Loading branch information
qrort committed Dec 18, 2024
1 parent 4332f1c commit 90bd3ed
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 20 deletions.
11 changes: 5 additions & 6 deletions internal/connectors/db/process_result_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package db

import (
"fmt"
"strings"
"time"

"ydbcp/internal/types"
Expand Down Expand Up @@ -108,7 +107,7 @@ func ReadBackupFromResultSet(res result.Result) (*types.Backup, error) {

sourcePathsSlice := make([]string, 0)
if sourcePaths != nil {
sourcePathsSlice = strings.Split(*sourcePaths, ",")
sourcePathsSlice = types.ParseSourcePaths(*sourcePaths)
}

return &types.Backup{
Expand Down Expand Up @@ -191,10 +190,10 @@ func ReadOperationFromResultSet(res result.Result) (types.Operation, error) {
sourcePathsSlice := make([]string, 0)
sourcePathsToExcludeSlice := make([]string, 0)
if sourcePaths != nil {
sourcePathsSlice = strings.Split(*sourcePaths, ",")
sourcePathsSlice = types.ParseSourcePaths(*sourcePaths)
}
if sourcePathsToExclude != nil {
sourcePathsToExcludeSlice = strings.Split(*sourcePathsToExclude, ",")
sourcePathsToExcludeSlice = types.ParseSourcePaths(*sourcePathsToExclude)
}

if updatedAt != nil {
Expand Down Expand Up @@ -363,10 +362,10 @@ func ReadBackupScheduleFromResultSet(res result.Result, withRPOInfo bool) (*type
var sourcePathsSlice []string
var sourcePathsToExcludeSlice []string
if sourcePaths != nil {
sourcePathsSlice = strings.Split(*sourcePaths, ",")
sourcePathsSlice = types.ParseSourcePaths(*sourcePaths)
}
if sourcePathsToExclude != nil {
sourcePathsToExcludeSlice = strings.Split(*sourcePathsToExclude, ",")
sourcePathsToExcludeSlice = types.ParseSourcePaths(*sourcePathsToExclude)
}

var ttlDuration *durationpb.Duration
Expand Down
20 changes: 10 additions & 10 deletions internal/connectors/db/yql/queries/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,12 @@ func BuildCreateOperationQuery(operation types.Operation, index int) WriteSingle
table_types.StringValueFromString(tb.YdbOperationId),
)
if len(tb.SourcePaths) > 0 {
d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(tb.SourcePaths, ",")))
d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(tb.SourcePaths)))
}
if len(tb.SourcePathsToExclude) > 0 {
d.AddValueParam(
"$paths_to_exclude",
table_types.StringValueFromString(strings.Join(tb.SourcePathsToExclude, ",")),
table_types.StringValueFromString(types.SerializeSourcePaths(tb.SourcePathsToExclude)),
)
}
if tb.ParentOperationID != nil {
Expand All @@ -147,12 +147,12 @@ func BuildCreateOperationQuery(operation types.Operation, index int) WriteSingle
table_types.StringValueFromString(tbwr.YdbConnectionParams.Endpoint),
)
if len(tbwr.SourcePaths) > 0 {
d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(tbwr.SourcePaths, ",")))
d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(tbwr.SourcePaths)))
}
if len(tbwr.SourcePathsToExclude) > 0 {
d.AddValueParam(
"$paths_to_exclude",
table_types.StringValueFromString(strings.Join(tbwr.SourcePathsToExclude, ",")),
table_types.StringValueFromString(types.SerializeSourcePaths(tbwr.SourcePathsToExclude)),
)
}
d.AddValueParam("$retries", table_types.Uint32Value(uint32(tbwr.Retries)))
Expand Down Expand Up @@ -205,7 +205,7 @@ func BuildCreateOperationQuery(operation types.Operation, index int) WriteSingle
)

if len(rb.SourcePaths) > 0 {
d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(rb.SourcePaths, ",")))
d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(rb.SourcePaths)))
}
} else if operation.GetType() == types.OperationTypeDB {
db, ok := operation.(*types.DeleteBackupOperation)
Expand Down Expand Up @@ -311,7 +311,7 @@ func BuildCreateBackupQuery(b types.Backup, index int) WriteSingleTableQueryImpl
d.AddValueParam("$status", table_types.StringValueFromString(b.Status))
d.AddValueParam("$message", table_types.StringValueFromString(b.Message))
d.AddValueParam("$size", table_types.Int64Value(b.Size))
d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(b.SourcePaths, ",")))
d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(b.SourcePaths)))
if b.ScheduleID != nil {
d.AddValueParam("$schedule_id", table_types.StringValueFromString(*b.ScheduleID))
}
Expand Down Expand Up @@ -356,12 +356,12 @@ func BuildCreateBackupScheduleQuery(schedule types.BackupSchedule, index int) Wr
d.AddValueParam("$ttl", table_types.IntervalValueFromDuration(schedule.ScheduleSettings.Ttl.AsDuration()))
}
if len(schedule.SourcePaths) > 0 {
d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(schedule.SourcePaths, ",")))
d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(schedule.SourcePaths)))
}
if len(schedule.SourcePathsToExclude) > 0 {
d.AddValueParam(
"$paths_to_exclude",
table_types.StringValueFromString(strings.Join(schedule.SourcePathsToExclude, ",")),
table_types.StringValueFromString(types.SerializeSourcePaths(schedule.SourcePathsToExclude)),
)
}
if schedule.Audit != nil {
Expand Down Expand Up @@ -395,14 +395,14 @@ func BuildUpdateBackupScheduleQuery(schedule types.BackupSchedule, index int) Wr
d.AddValueParam("$crontab", table_types.StringValueFromString(schedule.ScheduleSettings.SchedulePattern.Crontab))

if len(schedule.SourcePaths) > 0 {
d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(schedule.SourcePaths, ",")))
d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(schedule.SourcePaths)))
} else {
d.AddValueParam("$paths", table_types.NullableStringValueFromString(nil))
}
if len(schedule.SourcePathsToExclude) > 0 {
d.AddValueParam(
"$paths_to_exclude",
table_types.StringValueFromString(strings.Join(schedule.SourcePathsToExclude, ",")),
table_types.StringValueFromString(types.SerializeSourcePaths(schedule.SourcePathsToExclude)),
)
} else {
d.AddValueParam("$paths_to_exclude", table_types.NullableStringValueFromString(nil))
Expand Down
7 changes: 3 additions & 4 deletions internal/connectors/db/yql/queries/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package queries

import (
"context"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -246,11 +245,11 @@ UPSERT INTO Operations (id, type, status, message, initiated, created_at, contai
),
table.ValueParam(
"$paths_1",
table_types.StringValueFromString(strings.Join(tbOp.SourcePaths, ",")),
table_types.StringValueFromString(types.SerializeSourcePaths(tbOp.SourcePaths)),
),
table.ValueParam(
"$paths_to_exclude_1",
table_types.StringValueFromString(strings.Join(tbOp.SourcePathsToExclude, ",")),
table_types.StringValueFromString(types.SerializeSourcePaths(tbOp.SourcePathsToExclude)),
),
)
)
Expand Down Expand Up @@ -307,7 +306,7 @@ func TestQueryBuilder_CreateBackupSchedule(t *testing.T) {
table.ValueParam(
"$ttl_0", table_types.IntervalValueFromDuration(schedule.ScheduleSettings.Ttl.AsDuration()),
),
table.ValueParam("$paths_0", table_types.StringValueFromString(strings.Join(schedule.SourcePaths, ","))),
table.ValueParam("$paths_0", table_types.StringValueFromString(types.SerializeSourcePaths(schedule.SourcePaths))),
table.ValueParam("$initiated_0", table_types.StringValueFromString(schedule.Audit.Creator)),
table.ValueParam("$created_at_0", table_types.TimestampValueFromTime(schedule.Audit.CreatedAt.AsTime())),
table.ValueParam(
Expand Down
31 changes: 31 additions & 0 deletions internal/types/source_paths.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package types

import (
"encoding/base64"
"strings"
)

func ParseSourcePaths(str string) []string {
if str == "" {
return make([]string, 0)
}
codedSlice := strings.Split(str, ",")
slice := make([]string, len(codedSlice))
for i, s := range codedSlice {
data, err := base64.StdEncoding.DecodeString(s)
if err != nil {
slice[i] = s
} else {
slice[i] = string(data)
}
}
return slice
}

func SerializeSourcePaths(slice []string) string {
codedSlice := make([]string, len(slice))
for i, s := range slice {
codedSlice[i] = base64.StdEncoding.EncodeToString([]byte(s))
}
return strings.Join(codedSlice, ",")
}
92 changes: 92 additions & 0 deletions internal/types/source_paths_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package types

import (
"reflect"
"testing"
)

func TestParseSourcePaths(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "Parse base64",
input: "L3Rlc3RpbmctZ2xvYmFsL2lhbQ==,L3BhdGgsd2l0aCxjb21tYXM=",
expected: []string{"/testing-global/iam", "/path,with,commas"},
},
{
name: "Parse plaintext",
input: "hello,world",
expected: []string{"hello", "world"},
},
{
name: "Empty input string",
input: "",
expected: []string{},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ParseSourcePaths(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

func TestSerializeSourcePaths(t *testing.T) {
tests := []struct {
name string
input []string
expected string
}{
{
name: "Serialize base64",
input: []string{"hello", "world"},
expected: "aGVsbG8=,d29ybGQ=",
},
{
name: "Serialize plaintext",
input: []string{"hello"},
expected: "aGVsbG8=",
},
{
name: "Serialize empty slice",
input: []string{},
expected: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SerializeSourcePaths(tt.input)
if result != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, result)
}
})
}
}

func TestSymmetry(t *testing.T) {
tests := []struct {
input []string
}{
{
input: []string{"/testing-global/iam", "idk,,,strangepath"},
},
}

for _, tt := range tests {
t.Run("", func(t *testing.T) {
serialized := SerializeSourcePaths(tt.input)
parsed := ParseSourcePaths(serialized)
if !reflect.DeepEqual(parsed, tt.input) {
t.Errorf("expected %v, got %v", tt.input, parsed)
}
})
}
}

0 comments on commit 90bd3ed

Please sign in to comment.