diff --git a/internal/connectors/db/process_result_set.go b/internal/connectors/db/process_result_set.go index 90c6184e..ae007aa5 100644 --- a/internal/connectors/db/process_result_set.go +++ b/internal/connectors/db/process_result_set.go @@ -2,7 +2,6 @@ package db import ( "fmt" - "strings" "time" "ydbcp/internal/types" @@ -108,7 +107,7 @@ func ReadBackupFromResultSet(res result.Result) (*types.Backup, error) { sourcePathsSlice := make([]string, 0) if sourcePaths != nil { - sourcePathsSlice = strings.Split(*sourcePaths, ",") + sourcePathsSlice = types.ParseSourcePaths(*sourcePaths) } return &types.Backup{ @@ -191,10 +190,10 @@ func ReadOperationFromResultSet(res result.Result) (types.Operation, error) { sourcePathsSlice := make([]string, 0) sourcePathsToExcludeSlice := make([]string, 0) if sourcePaths != nil { - sourcePathsSlice = strings.Split(*sourcePaths, ",") + sourcePathsSlice = types.ParseSourcePaths(*sourcePaths) } if sourcePathsToExclude != nil { - sourcePathsToExcludeSlice = strings.Split(*sourcePathsToExclude, ",") + sourcePathsToExcludeSlice = types.ParseSourcePaths(*sourcePathsToExclude) } if updatedAt != nil { @@ -363,10 +362,10 @@ func ReadBackupScheduleFromResultSet(res result.Result, withRPOInfo bool) (*type var sourcePathsSlice []string var sourcePathsToExcludeSlice []string if sourcePaths != nil { - sourcePathsSlice = strings.Split(*sourcePaths, ",") + sourcePathsSlice = types.ParseSourcePaths(*sourcePaths) } if sourcePathsToExclude != nil { - sourcePathsToExcludeSlice = strings.Split(*sourcePathsToExclude, ",") + sourcePathsToExcludeSlice = types.ParseSourcePaths(*sourcePathsToExclude) } var ttlDuration *durationpb.Duration diff --git a/internal/connectors/db/yql/queries/write.go b/internal/connectors/db/yql/queries/write.go index cd5dd942..80236c37 100644 --- a/internal/connectors/db/yql/queries/write.go +++ b/internal/connectors/db/yql/queries/write.go @@ -118,12 +118,12 @@ func BuildCreateOperationQuery(operation types.Operation, index int) WriteSingle table_types.StringValueFromString(tb.YdbOperationId), ) if len(tb.SourcePaths) > 0 { - d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(tb.SourcePaths, ","))) + d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(tb.SourcePaths))) } if len(tb.SourcePathsToExclude) > 0 { d.AddValueParam( "$paths_to_exclude", - table_types.StringValueFromString(strings.Join(tb.SourcePathsToExclude, ",")), + table_types.StringValueFromString(types.SerializeSourcePaths(tb.SourcePathsToExclude)), ) } if tb.ParentOperationID != nil { @@ -147,12 +147,12 @@ func BuildCreateOperationQuery(operation types.Operation, index int) WriteSingle table_types.StringValueFromString(tbwr.YdbConnectionParams.Endpoint), ) if len(tbwr.SourcePaths) > 0 { - d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(tbwr.SourcePaths, ","))) + d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(tbwr.SourcePaths))) } if len(tbwr.SourcePathsToExclude) > 0 { d.AddValueParam( "$paths_to_exclude", - table_types.StringValueFromString(strings.Join(tbwr.SourcePathsToExclude, ",")), + table_types.StringValueFromString(types.SerializeSourcePaths(tbwr.SourcePathsToExclude)), ) } d.AddValueParam("$retries", table_types.Uint32Value(uint32(tbwr.Retries))) @@ -205,7 +205,7 @@ func BuildCreateOperationQuery(operation types.Operation, index int) WriteSingle ) if len(rb.SourcePaths) > 0 { - d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(rb.SourcePaths, ","))) + d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(rb.SourcePaths))) } } else if operation.GetType() == types.OperationTypeDB { db, ok := operation.(*types.DeleteBackupOperation) @@ -311,7 +311,7 @@ func BuildCreateBackupQuery(b types.Backup, index int) WriteSingleTableQueryImpl d.AddValueParam("$status", table_types.StringValueFromString(b.Status)) d.AddValueParam("$message", table_types.StringValueFromString(b.Message)) d.AddValueParam("$size", table_types.Int64Value(b.Size)) - d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(b.SourcePaths, ","))) + d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(b.SourcePaths))) if b.ScheduleID != nil { d.AddValueParam("$schedule_id", table_types.StringValueFromString(*b.ScheduleID)) } @@ -356,12 +356,12 @@ func BuildCreateBackupScheduleQuery(schedule types.BackupSchedule, index int) Wr d.AddValueParam("$ttl", table_types.IntervalValueFromDuration(schedule.ScheduleSettings.Ttl.AsDuration())) } if len(schedule.SourcePaths) > 0 { - d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(schedule.SourcePaths, ","))) + d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(schedule.SourcePaths))) } if len(schedule.SourcePathsToExclude) > 0 { d.AddValueParam( "$paths_to_exclude", - table_types.StringValueFromString(strings.Join(schedule.SourcePathsToExclude, ",")), + table_types.StringValueFromString(types.SerializeSourcePaths(schedule.SourcePathsToExclude)), ) } if schedule.Audit != nil { @@ -395,14 +395,14 @@ func BuildUpdateBackupScheduleQuery(schedule types.BackupSchedule, index int) Wr d.AddValueParam("$crontab", table_types.StringValueFromString(schedule.ScheduleSettings.SchedulePattern.Crontab)) if len(schedule.SourcePaths) > 0 { - d.AddValueParam("$paths", table_types.StringValueFromString(strings.Join(schedule.SourcePaths, ","))) + d.AddValueParam("$paths", table_types.StringValueFromString(types.SerializeSourcePaths(schedule.SourcePaths))) } else { d.AddValueParam("$paths", table_types.NullableStringValueFromString(nil)) } if len(schedule.SourcePathsToExclude) > 0 { d.AddValueParam( "$paths_to_exclude", - table_types.StringValueFromString(strings.Join(schedule.SourcePathsToExclude, ",")), + table_types.StringValueFromString(types.SerializeSourcePaths(schedule.SourcePathsToExclude)), ) } else { d.AddValueParam("$paths_to_exclude", table_types.NullableStringValueFromString(nil)) diff --git a/internal/connectors/db/yql/queries/write_test.go b/internal/connectors/db/yql/queries/write_test.go index a1121b76..3a39f663 100644 --- a/internal/connectors/db/yql/queries/write_test.go +++ b/internal/connectors/db/yql/queries/write_test.go @@ -2,7 +2,6 @@ package queries import ( "context" - "strings" "testing" "time" @@ -246,11 +245,11 @@ UPSERT INTO Operations (id, type, status, message, initiated, created_at, contai ), table.ValueParam( "$paths_1", - table_types.StringValueFromString(strings.Join(tbOp.SourcePaths, ",")), + table_types.StringValueFromString(types.SerializeSourcePaths(tbOp.SourcePaths)), ), table.ValueParam( "$paths_to_exclude_1", - table_types.StringValueFromString(strings.Join(tbOp.SourcePathsToExclude, ",")), + table_types.StringValueFromString(types.SerializeSourcePaths(tbOp.SourcePathsToExclude)), ), ) ) @@ -307,7 +306,7 @@ func TestQueryBuilder_CreateBackupSchedule(t *testing.T) { table.ValueParam( "$ttl_0", table_types.IntervalValueFromDuration(schedule.ScheduleSettings.Ttl.AsDuration()), ), - table.ValueParam("$paths_0", table_types.StringValueFromString(strings.Join(schedule.SourcePaths, ","))), + table.ValueParam("$paths_0", table_types.StringValueFromString(types.SerializeSourcePaths(schedule.SourcePaths))), table.ValueParam("$initiated_0", table_types.StringValueFromString(schedule.Audit.Creator)), table.ValueParam("$created_at_0", table_types.TimestampValueFromTime(schedule.Audit.CreatedAt.AsTime())), table.ValueParam( diff --git a/internal/types/source_paths.go b/internal/types/source_paths.go new file mode 100644 index 00000000..9ded2d27 --- /dev/null +++ b/internal/types/source_paths.go @@ -0,0 +1,31 @@ +package types + +import ( + "encoding/base64" + "strings" +) + +func ParseSourcePaths(str string) []string { + if str == "" { + return make([]string, 0) + } + codedSlice := strings.Split(str, ",") + slice := make([]string, len(codedSlice)) + for i, s := range codedSlice { + data, err := base64.StdEncoding.DecodeString(s) + if err != nil { + slice[i] = s + } else { + slice[i] = string(data) + } + } + return slice +} + +func SerializeSourcePaths(slice []string) string { + codedSlice := make([]string, len(slice)) + for i, s := range slice { + codedSlice[i] = base64.StdEncoding.EncodeToString([]byte(s)) + } + return strings.Join(codedSlice, ",") +} diff --git a/internal/types/source_paths_test.go b/internal/types/source_paths_test.go new file mode 100644 index 00000000..2e150473 --- /dev/null +++ b/internal/types/source_paths_test.go @@ -0,0 +1,92 @@ +package types + +import ( + "reflect" + "testing" +) + +func TestParseSourcePaths(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "Parse base64", + input: "L3Rlc3RpbmctZ2xvYmFsL2lhbQ==,L3BhdGgsd2l0aCxjb21tYXM=", + expected: []string{"/testing-global/iam", "/path,with,commas"}, + }, + { + name: "Parse plaintext", + input: "hello,world", + expected: []string{"hello", "world"}, + }, + { + name: "Empty input string", + input: "", + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseSourcePaths(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestSerializeSourcePaths(t *testing.T) { + tests := []struct { + name string + input []string + expected string + }{ + { + name: "Serialize base64", + input: []string{"hello", "world"}, + expected: "aGVsbG8=,d29ybGQ=", + }, + { + name: "Serialize plaintext", + input: []string{"hello"}, + expected: "aGVsbG8=", + }, + { + name: "Serialize empty slice", + input: []string{}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SerializeSourcePaths(tt.input) + if result != tt.expected { + t.Errorf("expected %s, got %s", tt.expected, result) + } + }) + } +} + +func TestSymmetry(t *testing.T) { + tests := []struct { + input []string + }{ + { + input: []string{"/testing-global/iam", "idk,,,strangepath"}, + }, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + serialized := SerializeSourcePaths(tt.input) + parsed := ParseSourcePaths(serialized) + if !reflect.DeepEqual(parsed, tt.input) { + t.Errorf("expected %v, got %v", tt.input, parsed) + } + }) + } +}