Skip to content

Commit

Permalink
Add lastpk value for source and target
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Lord <[email protected]>
  • Loading branch information
mattlord committed Jan 9, 2025
1 parent 06def14 commit b4670c0
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 35 deletions.
99 changes: 70 additions & 29 deletions go/vt/vttablet/tabletmanager/vdiff/table_differ.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ type tableDiffer struct {
targetPrimitive engine.Primitive

// sourceQuery is computed from the associated query for this table in the vreplication workflow's Rule Filter
sourceQuery string
table *tabletmanagerdatapb.TableDefinition
lastPK *querypb.QueryResult
sourceQuery string
table *tabletmanagerdatapb.TableDefinition
lastSourcePK *querypb.QueryResult
lastTargetPK *querypb.QueryResult

// wgShardStreamers is used, with a cancellable context, to wait for all shard streamers
// to finish after each diff is complete.
Expand Down Expand Up @@ -349,7 +350,7 @@ func (td *tableDiffer) startTargetDataStream(ctx context.Context) error {
ct := td.wd.ct
gtidch := make(chan string, 1)
ct.targetShardStreamer.result = make(chan *sqltypes.Result, 1)
go td.streamOneShard(ctx, ct.targetShardStreamer, td.tablePlan.targetQuery, td.lastPK, gtidch)
go td.streamOneShard(ctx, ct.targetShardStreamer, td.tablePlan.targetQuery, td.lastTargetPK, gtidch)
gtid, ok := <-gtidch
if !ok {
log.Infof("streaming error: %v", ct.targetShardStreamer.err)
Expand All @@ -364,7 +365,7 @@ func (td *tableDiffer) startSourceDataStreams(ctx context.Context) error {
if err := td.forEachSource(func(source *migrationSource) error {
gtidch := make(chan string, 1)
source.result = make(chan *sqltypes.Result, 1)
go td.streamOneShard(ctx, source.shardStreamer, td.tablePlan.sourceQuery, td.lastPK, gtidch)
go td.streamOneShard(ctx, source.shardStreamer, td.tablePlan.sourceQuery, td.lastSourcePK, gtidch)

gtid, ok := <-gtidch
if !ok {
Expand Down Expand Up @@ -527,13 +528,13 @@ func (td *tableDiffer) diff(ctx context.Context, coreOpts *tabletmanagerdatapb.V

sourceExecutor := newPrimitiveExecutor(ctx, td.sourcePrimitive, "source")
targetExecutor := newPrimitiveExecutor(ctx, td.targetPrimitive, "target")
var sourceRow, lastProcessedRow, targetRow []sqltypes.Value
var sourceRow, targetRow []sqltypes.Value
advanceSource := true
advanceTarget := true

// Save our progress when we finish the run.
defer func() {
if err := td.updateTableProgress(dbClient, dr, lastProcessedRow); err != nil {
if err := td.updateTableProgress(dbClient, dr, sourceRow, targetRow); err != nil {
log.Errorf("Failed to update vdiff progress on %s table: %v", td.table.Name, err)
}
globalStats.RowsDiffedCount.Add(dr.ProcessedRows)
Expand All @@ -544,8 +545,6 @@ func (td *tableDiffer) diff(ctx context.Context, coreOpts *tabletmanagerdatapb.V
maxReportSampleRows := reportOpts.GetMaxSampleRows()

for {
lastProcessedRow = sourceRow

select {
case <-ctx.Done():
return nil, vterrors.Errorf(vtrpcpb.Code_CANCELED, "context has expired")
Expand Down Expand Up @@ -683,7 +682,7 @@ func (td *tableDiffer) diff(ctx context.Context, coreOpts *tabletmanagerdatapb.V
// approximate progress information but without too much overhead for when it's not
// needed or even desired.
if dr.ProcessedRows%1e4 == 0 {
if err := td.updateTableProgress(dbClient, dr, sourceRow); err != nil {
if err := td.updateTableProgress(dbClient, dr, sourceRow, targetRow); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -717,46 +716,68 @@ func (td *tableDiffer) compare(sourceRow, targetRow []sqltypes.Value, cols []com
return 0, nil
}

func (td *tableDiffer) updateTableProgress(dbClient binlogplayer.DBClient, dr *DiffReport, lastRow []sqltypes.Value) error {
func (td *tableDiffer) updateTableProgress(dbClient binlogplayer.DBClient, dr *DiffReport, lastSourceRow, lastTargetRow []sqltypes.Value) error {
if dr == nil {
return fmt.Errorf("cannot update progress with a nil diff report")
}
var lastPK []byte
var err error
var query string
rpt, err := json.Marshal(dr)
if err != nil {
return err
}
if lastRow != nil {
lastPK, err = td.lastPKFromRow(lastRow)

if lastSourceRow == nil && lastTargetRow == nil {
query, err = sqlparser.ParseAndBind(sqlUpdateTableNoProgress,
sqltypes.Int64BindVariable(dr.ProcessedRows),
sqltypes.StringBindVariable(string(rpt)),
sqltypes.Int64BindVariable(td.wd.ct.id),
sqltypes.StringBindVariable(td.table.Name),
)
if err != nil {
return err
}

} else {
var lastSourcePK, lastTargetPK []byte
lastPK := make(map[string]string, 2)
if lastSourceRow != nil {
lastSourcePK, err = td.lastSourcePKFromRow(lastSourceRow)
if err != nil {
return err
}
lastPK["source"] = string(lastSourcePK)
log.Errorf("DEBUG: updateTableProgress lastSourcePK: %s", string(lastSourcePK))
}
if lastTargetRow != nil {
lastTargetPK, err = td.lastTargetPKFromRow(lastTargetRow)
if err != nil {
return err
}
lastPK["target"] = string(lastTargetPK)
log.Errorf("DEBUG: updateTableProgress lastTargetPK: %s", string(lastTargetPK))
}
if td.wd.opts.CoreOptions.MaxDiffSeconds > 0 {
// Update the in-memory lastPK as well so that we can restart the table
// diff if needed.
lastpkpb := &querypb.QueryResult{}
if err := prototext.Unmarshal(lastPK, lastpkpb); err != nil {
lastSourcePKPB, lastTargetPKPB := &querypb.QueryResult{}, &querypb.QueryResult{}
if err := prototext.Unmarshal(lastSourcePK, lastSourcePKPB); err != nil {
return err
}
td.lastSourcePK = lastSourcePKPB
if err := prototext.Unmarshal(lastTargetPK, lastTargetPKPB); err != nil {
return err
}
td.lastPK = lastpkpb
td.lastTargetPK = lastTargetPKPB
}

query, err = sqlparser.ParseAndBind(sqlUpdateTableProgress,
sqltypes.Int64BindVariable(dr.ProcessedRows),
sqltypes.StringBindVariable(string(lastPK)),
sqltypes.StringBindVariable(string(rpt)),
sqltypes.Int64BindVariable(td.wd.ct.id),
sqltypes.StringBindVariable(td.table.Name),
)
log.Errorf("DEBUG: updateTableProgress lastPK map: %v", lastPK)
lastPKJS, err := json.Marshal(lastPK)
if err != nil {
return err
}
} else {
query, err = sqlparser.ParseAndBind(sqlUpdateTableNoProgress,
log.Errorf("DEBUG: updateTableProgress lastPK JSON: %v", lastPKJS)
query, err = sqlparser.ParseAndBind(sqlUpdateTableProgress,
sqltypes.Int64BindVariable(dr.ProcessedRows),
sqltypes.StringBindVariable(string(lastPKJS)),
sqltypes.StringBindVariable(string(rpt)),
sqltypes.Int64BindVariable(td.wd.ct.id),
sqltypes.StringBindVariable(td.table.Name),
Expand Down Expand Up @@ -832,7 +853,7 @@ func updateTableMismatch(dbClient binlogplayer.DBClient, vdiffID int64, table st
return nil
}

func (td *tableDiffer) lastPKFromRow(row []sqltypes.Value) ([]byte, error) {
func (td *tableDiffer) lastTargetPKFromRow(row []sqltypes.Value) ([]byte, error) {
pkColCnt := len(td.tablePlan.pkCols)
pkFields := make([]*querypb.Field, pkColCnt)
pkVals := make([]sqltypes.Value, pkColCnt)
Expand All @@ -847,6 +868,26 @@ func (td *tableDiffer) lastPKFromRow(row []sqltypes.Value) ([]byte, error) {
return buf, err
}

// lastSourcePKFromRow serializes the source-side primary key of row as a
// prototext-encoded querypb.QueryResult containing the PK fields and a single
// row holding the PK values.
//
// The PK columns come from td.tablePlan.sourcePkCols. When that list is empty
// it is first set to td.tablePlan.pkCols, and the assignment is kept on the
// table plan for subsequent calls.
//
// An error is returned, instead of panicking, when a PK column index does not
// exist in row or in the table plan's field list.
func (td *tableDiffer) lastSourcePKFromRow(row []sqltypes.Value) ([]byte, error) {
	if len(td.tablePlan.sourcePkCols) == 0 {
		// If there are no PKs on the source then we use
		// the same PK[E] columns as the target.
		td.tablePlan.sourcePkCols = td.tablePlan.pkCols
	}
	pkCols := td.tablePlan.sourcePkCols
	fields := td.tablePlan.table.Fields
	pkFields := make([]*querypb.Field, len(pkCols))
	pkVals := make([]sqltypes.Value, len(pkCols))
	for i, colIndex := range pkCols {
		// Both the field list and the row are indexed with colIndex, so it
		// has to be valid for each of them.
		if colIndex < 0 || colIndex >= len(fields) || colIndex >= len(row) {
			return nil, fmt.Errorf("source PK column index %d is out of range for table %s (%d fields, %d row values)",
				colIndex, td.table.Name, len(fields), len(row))
		}
		pkFields[i] = fields[colIndex]
		pkVals[i] = row[colIndex]
	}
	return prototext.Marshal(&querypb.QueryResult{
		Fields: pkFields,
		Rows:   []*querypb.Row{sqltypes.RowToProto3(pkVals)},
	})
}

// If SourceTimeZone is defined in the BinlogSource (_vt.vreplication.source), the
// VReplication workflow would have converted the datetime columns expecting the
// source to have been in the SourceTimeZone and target in TargetTimeZone. We need
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vttablet/tabletmanager/vdiff/table_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ type tablePlan struct {
comparePKs []compareColInfo
// pkCols has the indices of PK cols in the select list
pkCols []int
// sourcePkCols has the indices of PK cols in the select
// list, but from the source keyspace. This is needed to
// properly store the lastpk for the source.
sourcePkCols []int

// selectPks is the list of pk columns as they appear in the select clause for the diff.
selectPks []int
Expand Down Expand Up @@ -207,6 +211,7 @@ func (tp *tablePlan) findPKs(dbClient binlogplayer.DBClient, targetSelect *sqlpa
if len(tp.table.PrimaryKeyColumns) == 0 {
return nil
}

var orderby sqlparser.OrderBy
for _, pk := range tp.table.PrimaryKeyColumns {
found := false
Expand Down
67 changes: 61 additions & 6 deletions go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ package vdiff

import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"runtime/debug"
"strings"
"time"

"golang.org/x/exp/maps"
"google.golang.org/protobuf/encoding/prototext"

"vitess.io/vitess/go/mysql/collations"
Expand Down Expand Up @@ -370,15 +373,59 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl
}

td := newTableDiffer(wd, table, sourceQuery)
lastpkpb, err := wd.getTableLastPK(dbClient, table.Name)

defer func() {
if r := recover(); r != nil {
log.Errorf("DEBUG: uncaught panic: %s", r)
log.Errorf("DEBUG: stack: %s", debug.Stack())
}
}()

lastPK, err := wd.getTableLastPK(dbClient, table.Name)
if err != nil {
return err
}
td.lastPK = lastpkpb
td.lastSourcePK = lastPK["source"]
td.lastTargetPK = lastPK["target"]
wd.tableDiffers[table.Name] = td
if _, err := td.buildTablePlan(dbClient, wd.ct.vde.dbName, wd.collationEnv); err != nil {
return err
}

// We get the PK columns from the source schema as well, as they can
// differ and determine the proper lastPK to use when saving progress.
// We use the first sourceShard as all of them should have the same schema.
sourceShardName := maps.Keys(wd.ct.sources)[0]
sourceShard, err := wd.ct.ts.GetShard(wd.ct.vde.ctx, wd.ct.sourceKeyspace, sourceShardName)
if err != nil {
return err
}
if sourceShard.PrimaryAlias == nil {
return fmt.Errorf("source shard %s has no primary", sourceShardName)
}
sourceTablet, err := wd.ct.ts.GetTablet(wd.ct.vde.ctx, sourceShard.PrimaryAlias)
if err != nil {
return fmt.Errorf("failed to get source shard %s primary", sourceShardName)
}
sourceSchema, err := wd.ct.tmc.GetSchema(wd.ct.vde.ctx, sourceTablet.Tablet, &tabletmanagerdatapb.GetSchemaRequest{
Tables: []string{table.Name},
})
if err != nil {
return err
}
log.Errorf("DEBUG: sourceTable.PrimaryKeyColumns: %v", sourceSchema.TableDefinitions[0].PrimaryKeyColumns)
sourcePKColumns := make(map[string]struct{}, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns))
td.tablePlan.sourcePkCols = make([]int, 0, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns))
for _, pkc := range sourceSchema.TableDefinitions[0].PrimaryKeyColumns {
sourcePKColumns[pkc] = struct{}{}
}
log.Errorf("DEBUG: sourcePKColumns: %v", sourcePKColumns)
for i, pkc := range table.PrimaryKeyColumns {
if _, ok := sourcePKColumns[pkc]; ok {
td.tablePlan.sourcePkCols = append(td.tablePlan.sourcePkCols, i)
}
}
log.Errorf("DEBUG: td.tablePlan.sourcePkCols: %v", td.tablePlan.sourcePkCols)
}
if len(wd.tableDiffers) == 0 {
return fmt.Errorf("no tables found to diff, %s:%s, on tablet %v",
Expand All @@ -388,7 +435,7 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl
}

// getTableLastPK gets the lastPK protobuf messages for a given vdiff table, keyed by "source" and "target".
func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableName string) (*querypb.QueryResult, error) {
func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableName string) (map[string]*querypb.QueryResult, error) {
query, err := sqlparser.ParseAndBind(sqlGetVDiffTable,
sqltypes.Int64BindVariable(wd.ct.id),
sqltypes.StringBindVariable(tableName),
Expand All @@ -406,11 +453,19 @@ func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableNa
return nil, err
}
if len(lastpk) != 0 {
var lastpkpb querypb.QueryResult
if err := prototext.Unmarshal(lastpk, &lastpkpb); err != nil {
lastPKBytes := make(map[string][]byte, 2)
lastPKResults := make(map[string]*querypb.QueryResult, 2)
if err := json.Unmarshal(lastpk, &lastPKBytes); err != nil {
return nil, err
}
return &lastpkpb, nil
log.Errorf("DEBUG: getTabletLastPK lastPKBytes: %v", lastPKBytes)
for k, v := range lastPKBytes {
if err := prototext.Unmarshal(v, lastPKResults[k]); err != nil {
return nil, err
}
}
log.Errorf("DEBUG: getTabletLastPK lastPKRResults: %v", lastPKResults)
return lastPKResults, nil
}
}
return nil, nil
Expand Down

0 comments on commit b4670c0

Please sign in to comment.