diff --git a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go index 102d7535af9..7526c218550 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go +++ b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go @@ -86,9 +86,10 @@ type tableDiffer struct { targetPrimitive engine.Primitive // sourceQuery is computed from the associated query for this table in the vreplication workflow's Rule Filter - sourceQuery string - table *tabletmanagerdatapb.TableDefinition - lastPK *querypb.QueryResult + sourceQuery string + table *tabletmanagerdatapb.TableDefinition + lastSourcePK *querypb.QueryResult + lastTargetPK *querypb.QueryResult // wgShardStreamers is used, with a cancellable context, to wait for all shard streamers // to finish after each diff is complete. @@ -349,7 +350,7 @@ func (td *tableDiffer) startTargetDataStream(ctx context.Context) error { ct := td.wd.ct gtidch := make(chan string, 1) ct.targetShardStreamer.result = make(chan *sqltypes.Result, 1) - go td.streamOneShard(ctx, ct.targetShardStreamer, td.tablePlan.targetQuery, td.lastPK, gtidch) + go td.streamOneShard(ctx, ct.targetShardStreamer, td.tablePlan.targetQuery, td.lastTargetPK, gtidch) gtid, ok := <-gtidch if !ok { log.Infof("streaming error: %v", ct.targetShardStreamer.err) @@ -364,7 +365,7 @@ func (td *tableDiffer) startSourceDataStreams(ctx context.Context) error { if err := td.forEachSource(func(source *migrationSource) error { gtidch := make(chan string, 1) source.result = make(chan *sqltypes.Result, 1) - go td.streamOneShard(ctx, source.shardStreamer, td.tablePlan.sourceQuery, td.lastPK, gtidch) + go td.streamOneShard(ctx, source.shardStreamer, td.tablePlan.sourceQuery, td.lastSourcePK, gtidch) gtid, ok := <-gtidch if !ok { @@ -527,13 +528,13 @@ func (td *tableDiffer) diff(ctx context.Context, coreOpts *tabletmanagerdatapb.V sourceExecutor := newPrimitiveExecutor(ctx, td.sourcePrimitive, "source") targetExecutor := newPrimitiveExecutor(ctx, td.targetPrimitive, "target") - var sourceRow, lastProcessedRow, targetRow []sqltypes.Value + var sourceRow, targetRow []sqltypes.Value advanceSource := true advanceTarget := true // Save our progress when we finish the run. defer func() { - if err := td.updateTableProgress(dbClient, dr, lastProcessedRow); err != nil { + if err := td.updateTableProgress(dbClient, dr, sourceRow, targetRow); err != nil { log.Errorf("Failed to update vdiff progress on %s table: %v", td.table.Name, err) } globalStats.RowsDiffedCount.Add(dr.ProcessedRows) @@ -544,8 +545,6 @@ func (td *tableDiffer) diff(ctx context.Context, coreOpts *tabletmanagerdatapb.V maxReportSampleRows := reportOpts.GetMaxSampleRows() for { - lastProcessedRow = sourceRow - select { case <-ctx.Done(): return nil, vterrors.Errorf(vtrpcpb.Code_CANCELED, "context has expired") @@ -683,7 +682,7 @@ func (td *tableDiffer) diff(ctx context.Context, coreOpts *tabletmanagerdatapb.V // approximate progress information but without too much overhead for when it's not // needed or even desired. if dr.ProcessedRows%1e4 == 0 { - if err := td.updateTableProgress(dbClient, dr, sourceRow); err != nil { + if err := td.updateTableProgress(dbClient, dr, sourceRow, targetRow); err != nil { return nil, err } } @@ -717,46 +716,68 @@ func (td *tableDiffer) compare(sourceRow, targetRow []sqltypes.Value, cols []com return 0, nil } -func (td *tableDiffer) updateTableProgress(dbClient binlogplayer.DBClient, dr *DiffReport, lastRow []sqltypes.Value) error { +func (td *tableDiffer) updateTableProgress(dbClient binlogplayer.DBClient, dr *DiffReport, lastSourceRow, lastTargetRow []sqltypes.Value) error { if dr == nil { return fmt.Errorf("cannot update progress with a nil diff report") } - var lastPK []byte var err error var query string rpt, err := json.Marshal(dr) if err != nil { return err } - if lastRow != nil { - lastPK, err = td.lastPKFromRow(lastRow) + + if lastSourceRow == nil && lastTargetRow == nil { + query, err = sqlparser.ParseAndBind(sqlUpdateTableNoProgress, + sqltypes.Int64BindVariable(dr.ProcessedRows), + sqltypes.StringBindVariable(string(rpt)), + sqltypes.Int64BindVariable(td.wd.ct.id), + sqltypes.StringBindVariable(td.table.Name), + ) if err != nil { return err } - + } else { + var lastSourcePK, lastTargetPK []byte + lastPK := make(map[string]string, 2) + if lastSourceRow != nil { + lastSourcePK, err = td.lastSourcePKFromRow(lastSourceRow) + if err != nil { + return err + } + lastPK["source"] = string(lastSourcePK) + log.Errorf("DEBUG: updateTableProgress lastSourcePK: %s", string(lastSourcePK)) + } + if lastTargetRow != nil { + lastTargetPK, err = td.lastTargetPKFromRow(lastTargetRow) + if err != nil { + return err + } + lastPK["target"] = string(lastTargetPK) + log.Errorf("DEBUG: updateTableProgress lastTargetPK: %s", string(lastTargetPK)) + } if td.wd.opts.CoreOptions.MaxDiffSeconds > 0 { // Update the in-memory lastPK as well so that we can restart the table // diff if needed. - lastpkpb := &querypb.QueryResult{} - if err := prototext.Unmarshal(lastPK, lastpkpb); err != nil { + lastSourcePKPB, lastTargetPKPB := &querypb.QueryResult{}, &querypb.QueryResult{} + if err := prototext.Unmarshal(lastSourcePK, lastSourcePKPB); err != nil { + return err + } + td.lastSourcePK = lastSourcePKPB + if err := prototext.Unmarshal(lastTargetPK, lastTargetPKPB); err != nil { return err } - td.lastPK = lastpkpb + td.lastTargetPK = lastTargetPKPB } - - query, err = sqlparser.ParseAndBind(sqlUpdateTableProgress, - sqltypes.Int64BindVariable(dr.ProcessedRows), - sqltypes.StringBindVariable(string(lastPK)), - sqltypes.StringBindVariable(string(rpt)), - sqltypes.Int64BindVariable(td.wd.ct.id), - sqltypes.StringBindVariable(td.table.Name), - ) + log.Errorf("DEBUG: updateTableProgress lastPK map: %v", lastPK) + lastPKJS, err := json.Marshal(lastPK) if err != nil { return err } - } else { - query, err = sqlparser.ParseAndBind(sqlUpdateTableNoProgress, + log.Errorf("DEBUG: updateTableProgress lastPK JSON: %v", lastPKJS) + query, err = sqlparser.ParseAndBind(sqlUpdateTableProgress, sqltypes.Int64BindVariable(dr.ProcessedRows), + sqltypes.StringBindVariable(string(lastPKJS)), sqltypes.StringBindVariable(string(rpt)), sqltypes.Int64BindVariable(td.wd.ct.id), sqltypes.StringBindVariable(td.table.Name), @@ -832,7 +853,7 @@ func updateTableMismatch(dbClient binlogplayer.DBClient, vdiffID int64, table st return nil } -func (td *tableDiffer) lastPKFromRow(row []sqltypes.Value) ([]byte, error) { +func (td *tableDiffer) lastTargetPKFromRow(row []sqltypes.Value) ([]byte, error) { pkColCnt := len(td.tablePlan.pkCols) pkFields := make([]*querypb.Field, pkColCnt) pkVals := make([]sqltypes.Value, pkColCnt) @@ -847,6 +868,26 @@ func (td *tableDiffer) lastPKFromRow(row []sqltypes.Value) ([]byte, error) { return buf, err } +func (td *tableDiffer) lastSourcePKFromRow(row []sqltypes.Value) ([]byte, error) { + if len(td.tablePlan.sourcePkCols) == 0 { + // If there are no PKs on the source then we use + // the same PK[E] columns as the target. + td.tablePlan.sourcePkCols = td.tablePlan.pkCols + } + pkColCnt := len(td.tablePlan.sourcePkCols) + pkFields := make([]*querypb.Field, pkColCnt) + pkVals := make([]sqltypes.Value, pkColCnt) + for i, colIndex := range td.tablePlan.sourcePkCols { + pkFields[i] = td.tablePlan.table.Fields[colIndex] + pkVals[i] = row[colIndex] + } + buf, err := prototext.Marshal(&querypb.QueryResult{ + Fields: pkFields, + Rows: []*querypb.Row{sqltypes.RowToProto3(pkVals)}, + }) + return buf, err +} + // If SourceTimeZone is defined in the BinlogSource (_vt.vreplication.source), the // VReplication workflow would have converted the datetime columns expecting the // source to have been in the SourceTimeZone and target in TargetTimeZone. We need diff --git a/go/vt/vttablet/tabletmanager/vdiff/table_plan.go b/go/vt/vttablet/tabletmanager/vdiff/table_plan.go index 836df8ffe94..f03cf20fb0a 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/table_plan.go +++ b/go/vt/vttablet/tabletmanager/vdiff/table_plan.go @@ -53,6 +53,10 @@ type tablePlan struct { comparePKs []compareColInfo // pkCols has the indices of PK cols in the select list pkCols []int + // sourcePkCols has the indices of PK cols in the select + // list, but from the source keyspace. This is needed to + // properly store the lastpk for the source. + sourcePkCols []int // selectPks is the list of pk columns as they appear in the select clause for the diff. selectPks []int @@ -207,6 +211,7 @@ func (tp *tablePlan) findPKs(dbClient binlogplayer.DBClient, targetSelect *sqlpa if len(tp.table.PrimaryKeyColumns) == 0 { return nil } + var orderby sqlparser.OrderBy for _, pk := range tp.table.PrimaryKeyColumns { found := false diff --git a/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go b/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go index ef30d8f14b0..01291f97a26 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go +++ b/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go @@ -18,12 +18,15 @@ package vdiff import ( "context" + "encoding/json" "errors" "fmt" "reflect" + "runtime/debug" "strings" "time" + "golang.org/x/exp/maps" "google.golang.org/protobuf/encoding/prototext" "vitess.io/vitess/go/mysql/collations" @@ -370,15 +373,59 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl } td := newTableDiffer(wd, table, sourceQuery) - lastpkpb, err := wd.getTableLastPK(dbClient, table.Name) + + defer func() { + if r := recover(); r != nil { + log.Errorf("DEBUG: uncaught panic: %s", r) + log.Errorf("DEBUG: stack: %s", debug.Stack()) + } + }() + + lastPK, err := wd.getTableLastPK(dbClient, table.Name) if err != nil { return err } - td.lastPK = lastpkpb + td.lastSourcePK = lastPK["source"] + td.lastTargetPK = lastPK["target"] wd.tableDiffers[table.Name] = td if _, err := td.buildTablePlan(dbClient, wd.ct.vde.dbName, wd.collationEnv); err != nil { return err } + + // We get the PK columns from the source schema as well, as they can + // differ and determine the proper lastPK to use when saving progress. + // We use the first sourceShard as all of them should have the same schema. + sourceShardName := maps.Keys(wd.ct.sources)[0] + sourceShard, err := wd.ct.ts.GetShard(wd.ct.vde.ctx, wd.ct.sourceKeyspace, sourceShardName) + if err != nil { + return err + } + if sourceShard.PrimaryAlias == nil { + return fmt.Errorf("source shard %s has no primary", sourceShardName) + } + sourceTablet, err := wd.ct.ts.GetTablet(wd.ct.vde.ctx, sourceShard.PrimaryAlias) + if err != nil { + return fmt.Errorf("failed to get source shard %s primary", sourceShardName) + } + sourceSchema, err := wd.ct.tmc.GetSchema(wd.ct.vde.ctx, sourceTablet.Tablet, &tabletmanagerdatapb.GetSchemaRequest{ + Tables: []string{table.Name}, + }) + if err != nil { + return err + } + log.Errorf("DEBUG: sourceTable.PrimaryKeyColumns: %v", sourceSchema.TableDefinitions[0].PrimaryKeyColumns) + sourcePKColumns := make(map[string]struct{}, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns)) + td.tablePlan.sourcePkCols = make([]int, 0, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns)) + for _, pkc := range sourceSchema.TableDefinitions[0].PrimaryKeyColumns { + sourcePKColumns[pkc] = struct{}{} + } + log.Errorf("DEBUG: sourcePKColumns: %v", sourcePKColumns) + for i, pkc := range table.PrimaryKeyColumns { + if _, ok := sourcePKColumns[pkc]; ok { + td.tablePlan.sourcePkCols = append(td.tablePlan.sourcePkCols, i) + } + } + log.Errorf("DEBUG: td.tablePlan.sourcePkCols: %v", td.tablePlan.sourcePkCols) } if len(wd.tableDiffers) == 0 { return fmt.Errorf("no tables found to diff, %s:%s, on tablet %v", @@ -388,7 +435,7 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl } // getTableLastPK gets the lastPK protobuf message for a given vdiff table. -func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableName string) (*querypb.QueryResult, error) { +func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableName string) (map[string]*querypb.QueryResult, error) { query, err := sqlparser.ParseAndBind(sqlGetVDiffTable, sqltypes.Int64BindVariable(wd.ct.id), sqltypes.StringBindVariable(tableName), @@ -406,11 +453,19 @@ func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableNa return nil, err } if len(lastpk) != 0 { - var lastpkpb querypb.QueryResult - if err := prototext.Unmarshal(lastpk, &lastpkpb); err != nil { + lastPKBytes := make(map[string][]byte, 2) + lastPKResults := make(map[string]*querypb.QueryResult, 2) + if err := json.Unmarshal(lastpk, &lastPKBytes); err != nil { return nil, err } - return &lastpkpb, nil + log.Errorf("DEBUG: getTabletLastPK lastPKBytes: %v", lastPKBytes) + for k, v := range lastPKBytes { + if err := prototext.Unmarshal(v, lastPKResults[k]); err != nil { + return nil, err + } + } + log.Errorf("DEBUG: getTabletLastPK lastPKRResults: %v", lastPKResults) + return lastPKResults, nil } } return nil, nil