From b97ad0e1340a424e0f86e4b074766caae9c37e11 Mon Sep 17 00:00:00 2001 From: Matt Lord Date: Wed, 8 Jan 2025 19:14:51 -0500 Subject: [PATCH] Add lastpk value for source and target Signed-off-by: Matt Lord --- .../tabletmanager/vdiff/table_differ.go | 85 ++++++++++++++----- .../tabletmanager/vdiff/table_plan.go | 5 ++ go/vt/vttablet/tabletmanager/vdiff/utils.go | 11 --- .../tabletmanager/vdiff/workflow_differ.go | 65 ++++++++++++-- 4 files changed, 124 insertions(+), 42 deletions(-) diff --git a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go index 102d7535af9..eeaced94166 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go +++ b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go @@ -86,9 +86,10 @@ type tableDiffer struct { targetPrimitive engine.Primitive // sourceQuery is computed from the associated query for this table in the vreplication workflow's Rule Filter - sourceQuery string - table *tabletmanagerdatapb.TableDefinition - lastPK *querypb.QueryResult + sourceQuery string + table *tabletmanagerdatapb.TableDefinition + lastSourcePK *querypb.QueryResult + lastTargetPK *querypb.QueryResult // wgShardStreamers is used, with a cancellable context, to wait for all shard streamers // to finish after each diff is complete. @@ -349,7 +350,7 @@ func (td *tableDiffer) startTargetDataStream(ctx context.Context) error { ct := td.wd.ct gtidch := make(chan string, 1) ct.targetShardStreamer.result = make(chan *sqltypes.Result, 1) - go td.streamOneShard(ctx, ct.targetShardStreamer, td.tablePlan.targetQuery, td.lastPK, gtidch) + go td.streamOneShard(ctx, ct.targetShardStreamer, td.tablePlan.targetQuery, td.lastTargetPK, gtidch) gtid, ok := <-gtidch if !ok { log.Infof("streaming error: %v", ct.targetShardStreamer.err) @@ -364,7 +365,7 @@ func (td *tableDiffer) startSourceDataStreams(ctx context.Context) error { if err := td.forEachSource(func(source *migrationSource) error { gtidch := make(chan string, 1) source.result = make(chan *sqltypes.Result, 1) - go td.streamOneShard(ctx, source.shardStreamer, td.tablePlan.sourceQuery, td.lastPK, gtidch) + go td.streamOneShard(ctx, source.shardStreamer, td.tablePlan.sourceQuery, td.lastSourcePK, gtidch) gtid, ok := <-gtidch if !ok { @@ -721,42 +722,60 @@ func (td *tableDiffer) updateTableProgress(dbClient binlogplayer.DBClient, dr *D if dr == nil { return fmt.Errorf("cannot update progress with a nil diff report") } - var lastPK []byte var err error var query string rpt, err := json.Marshal(dr) if err != nil { return err } - if lastRow != nil { - lastPK, err = td.lastPKFromRow(lastRow) + + if lastRow == nil { + query, err = sqlparser.ParseAndBind(sqlUpdateTableNoProgress, + sqltypes.Int64BindVariable(dr.ProcessedRows), + sqltypes.StringBindVariable(string(rpt)), + sqltypes.Int64BindVariable(td.wd.ct.id), + sqltypes.StringBindVariable(td.table.Name), + ) if err != nil { return err } - + } else { + var lastSourcePK, lastTargetPK []byte + lastPK := make(map[string]string, 2) + if lastRow != nil { + lastSourcePK, err = td.lastSourcePKFromRow(lastRow) + if err != nil { + return err + } + lastPK["source"] = string(lastSourcePK) + lastTargetPK, err = td.lastTargetPKFromRow(lastRow) + if err != nil { + return err + } + lastPK["target"] = string(lastTargetPK) + } if td.wd.opts.CoreOptions.MaxDiffSeconds > 0 { // Update the in-memory lastPK as well so that we can restart the table // diff if needed. - lastpkpb := &querypb.QueryResult{} - if err := prototext.Unmarshal(lastPK, lastpkpb); err != nil { + lastSourcePKPB, lastTargetPKPB := &querypb.QueryResult{}, &querypb.QueryResult{} + if err := prototext.Unmarshal(lastSourcePK, lastSourcePKPB); err != nil { return err } - td.lastPK = lastpkpb + td.lastSourcePK = lastSourcePKPB + if err := prototext.Unmarshal(lastTargetPK, lastTargetPKPB); err != nil { + return err + } + td.lastTargetPK = lastTargetPKPB } - - query, err = sqlparser.ParseAndBind(sqlUpdateTableProgress, - sqltypes.Int64BindVariable(dr.ProcessedRows), - sqltypes.StringBindVariable(string(lastPK)), - sqltypes.StringBindVariable(string(rpt)), - sqltypes.Int64BindVariable(td.wd.ct.id), - sqltypes.StringBindVariable(td.table.Name), - ) + //log.Errorf("DEBUG: updateTableProgress lastPK map: %v", lastPK) + lastPKJS, err := json.Marshal(lastPK) if err != nil { return err } - } else { - query, err = sqlparser.ParseAndBind(sqlUpdateTableNoProgress, + //log.Errorf("DEBUG: updateTableProgress lastPK JSON: %v", lastPKJS) + query, err = sqlparser.ParseAndBind(sqlUpdateTableProgress, sqltypes.Int64BindVariable(dr.ProcessedRows), + sqltypes.StringBindVariable(string(lastPKJS)), sqltypes.StringBindVariable(string(rpt)), sqltypes.Int64BindVariable(td.wd.ct.id), sqltypes.StringBindVariable(td.table.Name), @@ -832,7 +851,7 @@ func updateTableMismatch(dbClient binlogplayer.DBClient, vdiffID int64, table st return nil } -func (td *tableDiffer) lastPKFromRow(row []sqltypes.Value) ([]byte, error) { +func (td *tableDiffer) lastTargetPKFromRow(row []sqltypes.Value) ([]byte, error) { pkColCnt := len(td.tablePlan.pkCols) pkFields := make([]*querypb.Field, pkColCnt) pkVals := make([]sqltypes.Value, pkColCnt) @@ -847,6 +866,26 @@ func (td *tableDiffer) lastPKFromRow(row []sqltypes.Value) ([]byte, error) { return buf, err } +func (td *tableDiffer) lastSourcePKFromRow(row []sqltypes.Value) ([]byte, error) { + if len(td.tablePlan.sourcePkCols) == 0 { + // If there are no PKs on the source then we use + // the same PK[E] columns as the target. + td.tablePlan.sourcePkCols = td.tablePlan.pkCols + } + pkColCnt := len(td.tablePlan.sourcePkCols) + pkFields := make([]*querypb.Field, pkColCnt) + pkVals := make([]sqltypes.Value, pkColCnt) + for i, colIndex := range td.tablePlan.sourcePkCols { + pkFields[i] = td.tablePlan.table.Fields[colIndex] + pkVals[i] = row[colIndex] + } + buf, err := prototext.Marshal(&querypb.QueryResult{ + Fields: pkFields, + Rows: []*querypb.Row{sqltypes.RowToProto3(pkVals)}, + }) + return buf, err +} + // If SourceTimeZone is defined in the BinlogSource (_vt.vreplication.source), the // VReplication workflow would have converted the datetime columns expecting the // source to have been in the SourceTimeZone and target in TargetTimeZone. We need diff --git a/go/vt/vttablet/tabletmanager/vdiff/table_plan.go b/go/vt/vttablet/tabletmanager/vdiff/table_plan.go index 836df8ffe94..f03cf20fb0a 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/table_plan.go +++ b/go/vt/vttablet/tabletmanager/vdiff/table_plan.go @@ -53,6 +53,10 @@ type tablePlan struct { comparePKs []compareColInfo // pkCols has the indices of PK cols in the select list pkCols []int + // sourcePkCols has the indices of PK cols in the select + // list, but from the source keyspace. This is needed to + // properly store the lastpk for the source. + sourcePkCols []int // selectPks is the list of pk columns as they appear in the select clause for the diff. selectPks []int @@ -207,6 +211,7 @@ func (tp *tablePlan) findPKs(dbClient binlogplayer.DBClient, targetSelect *sqlpa if len(tp.table.PrimaryKeyColumns) == 0 { return nil } + var orderby sqlparser.OrderBy for _, pk := range tp.table.PrimaryKeyColumns { found := false diff --git a/go/vt/vttablet/tabletmanager/vdiff/utils.go b/go/vt/vttablet/tabletmanager/vdiff/utils.go index aeaa28972e0..68e8a6acb57 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/utils.go +++ b/go/vt/vttablet/tabletmanager/vdiff/utils.go @@ -80,17 +80,6 @@ func insertVDiffLog(ctx context.Context, dbClient binlogplayer.DBClient, vdiffID } } -func stringListContains(lst []string, item string) bool { - contains := false - for _, t := range lst { - if t == item { - contains = true - break - } - } - return contains -} - // copyNonKeyRangeExpressions copies all expressions from the input WHERE clause // to the output WHERE clause except for any in_keyrange() expressions. func copyNonKeyRangeExpressions(where *sqlparser.Where) *sqlparser.Where { diff --git a/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go b/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go index ef30d8f14b0..f5b4584e1be 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go +++ b/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go @@ -18,12 +18,15 @@ package vdiff import ( "context" + "encoding/json" "errors" "fmt" "reflect" + "slices" "strings" "time" + "golang.org/x/exp/maps" "google.golang.org/protobuf/encoding/prototext" "vitess.io/vitess/go/mysql/collations" @@ -344,7 +347,7 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl for _, table := range schm.TableDefinitions { // if user specified tables explicitly only use those, otherwise diff all tables in workflow - if len(specifiedTables) != 0 && !stringListContains(specifiedTables, table.Name) { + if len(specifiedTables) != 0 && !slices.Contains(specifiedTables, table.Name) { continue } if schema.IsInternalOperationTableName(table.Name) && !schema.IsOnlineDDLTableName(table.Name) { @@ -370,15 +373,52 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl } td := newTableDiffer(wd, table, sourceQuery) - lastpkpb, err := wd.getTableLastPK(dbClient, table.Name) + + lastPK, err := wd.getTableLastPK(dbClient, table.Name) if err != nil { return err } - td.lastPK = lastpkpb + td.lastSourcePK = lastPK["source"] + td.lastTargetPK = lastPK["target"] wd.tableDiffers[table.Name] = td if _, err := td.buildTablePlan(dbClient, wd.ct.vde.dbName, wd.collationEnv); err != nil { return err } + + // We get the PK columns from the source schema as well, as they can + // differ and determine the proper lastPK to use when saving progress. + // We use the first sourceShard as all of them should have the same schema. + sourceShardName := maps.Keys(wd.ct.sources)[0] + sourceShard, err := wd.ct.ts.GetShard(wd.ct.vde.ctx, wd.ct.sourceKeyspace, sourceShardName) + if err != nil { + return err + } + if sourceShard.PrimaryAlias == nil { + return fmt.Errorf("source shard %s has no primary", sourceShardName) + } + sourceTablet, err := wd.ct.ts.GetTablet(wd.ct.vde.ctx, sourceShard.PrimaryAlias) + if err != nil { + return fmt.Errorf("failed to get source shard %s primary", sourceShardName) + } + sourceSchema, err := wd.ct.tmc.GetSchema(wd.ct.vde.ctx, sourceTablet.Tablet, &tabletmanagerdatapb.GetSchemaRequest{ + Tables: []string{table.Name}, + }) + if err != nil { + return err + } + //log.Errorf("DEBUG: sourceTable.PrimaryKeyColumns: %v", sourceSchema.TableDefinitions[0].PrimaryKeyColumns) + sourcePKColumns := make(map[string]struct{}, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns)) + td.tablePlan.sourcePkCols = make([]int, 0, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns)) + for _, pkc := range sourceSchema.TableDefinitions[0].PrimaryKeyColumns { + sourcePKColumns[pkc] = struct{}{} + } + //log.Errorf("DEBUG: sourcePKColumns: %v", sourcePKColumns) + for i, pkc := range table.PrimaryKeyColumns { + if _, ok := sourcePKColumns[pkc]; ok { + td.tablePlan.sourcePkCols = append(td.tablePlan.sourcePkCols, i) + } + } + //log.Errorf("DEBUG: td.tablePlan.sourcePkCols: %v", td.tablePlan.sourcePkCols) } if len(wd.tableDiffers) == 0 { return fmt.Errorf("no tables found to diff, %s:%s, on tablet %v", @@ -388,7 +428,7 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl } // getTableLastPK gets the lastPK protobuf message for a given vdiff table. -func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableName string) (*querypb.QueryResult, error) { +func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableName string) (map[string]*querypb.QueryResult, error) { query, err := sqlparser.ParseAndBind(sqlGetVDiffTable, sqltypes.Int64BindVariable(wd.ct.id), sqltypes.StringBindVariable(tableName), @@ -406,11 +446,20 @@ func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableNa return nil, err } if len(lastpk) != 0 { - var lastpkpb querypb.QueryResult - if err := prototext.Unmarshal(lastpk, &lastpkpb); err != nil { - return nil, err + lastPK := make(map[string]string, 2) + lastPKResults := make(map[string]*querypb.QueryResult, 2) + if err := json.Unmarshal(lastpk, &lastPK); err != nil { + return nil, vterrors.Wrapf(err, "failed to unmarshal lastpk JSON for table %s", tableName) + } + //log.Errorf("DEBUG: getTabletLastPK lastPKBytes: %v", lastPK) + for k, v := range lastPK { + lastPKResults[k] = &querypb.QueryResult{} + if err := prototext.Unmarshal([]byte(v), lastPKResults[k]); err != nil { + return nil, vterrors.Wrapf(err, "failed to unmarshal lastpk QueryResult for table %s", tableName) + } } - return &lastpkpb, nil + //log.Errorf("DEBUG: getTabletLastPK lastPKRResults: %v", lastPKResults) + return lastPKResults, nil } } return nil, nil