Skip to content

Commit

Permalink
Add lastpk value for source and target
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Lord <[email protected]>
  • Loading branch information
mattlord committed Jan 9, 2025
1 parent 06def14 commit b97ad0e
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 42 deletions.
85 changes: 62 additions & 23 deletions go/vt/vttablet/tabletmanager/vdiff/table_differ.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ type tableDiffer struct {
targetPrimitive engine.Primitive

// sourceQuery is computed from the associated query for this table in the vreplication workflow's Rule Filter
sourceQuery string
table *tabletmanagerdatapb.TableDefinition
lastPK *querypb.QueryResult
sourceQuery string
table *tabletmanagerdatapb.TableDefinition
lastSourcePK *querypb.QueryResult
lastTargetPK *querypb.QueryResult

// wgShardStreamers is used, with a cancellable context, to wait for all shard streamers
// to finish after each diff is complete.
Expand Down Expand Up @@ -349,7 +350,7 @@ func (td *tableDiffer) startTargetDataStream(ctx context.Context) error {
ct := td.wd.ct
gtidch := make(chan string, 1)
ct.targetShardStreamer.result = make(chan *sqltypes.Result, 1)
go td.streamOneShard(ctx, ct.targetShardStreamer, td.tablePlan.targetQuery, td.lastPK, gtidch)
go td.streamOneShard(ctx, ct.targetShardStreamer, td.tablePlan.targetQuery, td.lastTargetPK, gtidch)
gtid, ok := <-gtidch
if !ok {
log.Infof("streaming error: %v", ct.targetShardStreamer.err)
Expand All @@ -364,7 +365,7 @@ func (td *tableDiffer) startSourceDataStreams(ctx context.Context) error {
if err := td.forEachSource(func(source *migrationSource) error {
gtidch := make(chan string, 1)
source.result = make(chan *sqltypes.Result, 1)
go td.streamOneShard(ctx, source.shardStreamer, td.tablePlan.sourceQuery, td.lastPK, gtidch)
go td.streamOneShard(ctx, source.shardStreamer, td.tablePlan.sourceQuery, td.lastSourcePK, gtidch)

gtid, ok := <-gtidch
if !ok {
Expand Down Expand Up @@ -721,42 +722,60 @@ func (td *tableDiffer) updateTableProgress(dbClient binlogplayer.DBClient, dr *D
if dr == nil {
return fmt.Errorf("cannot update progress with a nil diff report")
}
var lastPK []byte
var err error
var query string
rpt, err := json.Marshal(dr)
if err != nil {
return err
}
if lastRow != nil {
lastPK, err = td.lastPKFromRow(lastRow)

if lastRow == nil {
query, err = sqlparser.ParseAndBind(sqlUpdateTableNoProgress,
sqltypes.Int64BindVariable(dr.ProcessedRows),
sqltypes.StringBindVariable(string(rpt)),
sqltypes.Int64BindVariable(td.wd.ct.id),
sqltypes.StringBindVariable(td.table.Name),
)
if err != nil {
return err
}

} else {
var lastSourcePK, lastTargetPK []byte
lastPK := make(map[string]string, 2)
if lastRow != nil {
lastSourcePK, err = td.lastSourcePKFromRow(lastRow)
if err != nil {
return err
}
lastPK["source"] = string(lastSourcePK)
lastTargetPK, err = td.lastTargetPKFromRow(lastRow)
if err != nil {
return err
}
lastPK["target"] = string(lastTargetPK)
}
if td.wd.opts.CoreOptions.MaxDiffSeconds > 0 {
// Update the in-memory lastPK as well so that we can restart the table
// diff if needed.
lastpkpb := &querypb.QueryResult{}
if err := prototext.Unmarshal(lastPK, lastpkpb); err != nil {
lastSourcePKPB, lastTargetPKPB := &querypb.QueryResult{}, &querypb.QueryResult{}
if err := prototext.Unmarshal(lastSourcePK, lastSourcePKPB); err != nil {
return err
}
td.lastPK = lastpkpb
td.lastSourcePK = lastSourcePKPB
if err := prototext.Unmarshal(lastTargetPK, lastTargetPKPB); err != nil {
return err
}
td.lastTargetPK = lastTargetPKPB
}

query, err = sqlparser.ParseAndBind(sqlUpdateTableProgress,
sqltypes.Int64BindVariable(dr.ProcessedRows),
sqltypes.StringBindVariable(string(lastPK)),
sqltypes.StringBindVariable(string(rpt)),
sqltypes.Int64BindVariable(td.wd.ct.id),
sqltypes.StringBindVariable(td.table.Name),
)
//log.Errorf("DEBUG: updateTableProgress lastPK map: %v", lastPK)
lastPKJS, err := json.Marshal(lastPK)
if err != nil {
return err
}
} else {
query, err = sqlparser.ParseAndBind(sqlUpdateTableNoProgress,
//log.Errorf("DEBUG: updateTableProgress lastPK JSON: %v", lastPKJS)
query, err = sqlparser.ParseAndBind(sqlUpdateTableProgress,
sqltypes.Int64BindVariable(dr.ProcessedRows),
sqltypes.StringBindVariable(string(lastPKJS)),
sqltypes.StringBindVariable(string(rpt)),
sqltypes.Int64BindVariable(td.wd.ct.id),
sqltypes.StringBindVariable(td.table.Name),
Expand Down Expand Up @@ -832,7 +851,7 @@ func updateTableMismatch(dbClient binlogplayer.DBClient, vdiffID int64, table st
return nil
}

func (td *tableDiffer) lastPKFromRow(row []sqltypes.Value) ([]byte, error) {
func (td *tableDiffer) lastTargetPKFromRow(row []sqltypes.Value) ([]byte, error) {
pkColCnt := len(td.tablePlan.pkCols)
pkFields := make([]*querypb.Field, pkColCnt)
pkVals := make([]sqltypes.Value, pkColCnt)
Expand All @@ -847,6 +866,26 @@ func (td *tableDiffer) lastPKFromRow(row []sqltypes.Value) ([]byte, error) {
return buf, err
}

// lastSourcePKFromRow builds the lastpk value for the source side from the
// given row. It returns a prototext-encoded querypb.QueryResult holding the
// fields and values of the columns listed in tablePlan.sourcePkCols.
func (td *tableDiffer) lastSourcePKFromRow(row []sqltypes.Value) ([]byte, error) {
	// When no source PK columns have been recorded we fall back to the
	// same PK[E] columns as the target. The fallback is stored on the
	// table plan, so later calls see it as well.
	if len(td.tablePlan.sourcePkCols) == 0 {
		td.tablePlan.sourcePkCols = td.tablePlan.pkCols
	}
	srcCols := td.tablePlan.sourcePkCols
	fields := make([]*querypb.Field, len(srcCols))
	values := make([]sqltypes.Value, len(srcCols))
	for pos := range srcCols {
		// Each entry is an index into both the table's field list and the row.
		idx := srcCols[pos]
		fields[pos] = td.tablePlan.table.Fields[idx]
		values[pos] = row[idx]
	}
	lastPK := &querypb.QueryResult{
		Fields: fields,
		Rows:   []*querypb.Row{sqltypes.RowToProto3(values)},
	}
	return prototext.Marshal(lastPK)
}

// If SourceTimeZone is defined in the BinlogSource (_vt.vreplication.source), the
// VReplication workflow would have converted the datetime columns expecting the
// source to have been in the SourceTimeZone and target in TargetTimeZone. We need
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vttablet/tabletmanager/vdiff/table_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ type tablePlan struct {
comparePKs []compareColInfo
// pkCols has the indices of PK cols in the select list
pkCols []int
// sourcePkCols has the indices of PK cols in the select
// list, but from the source keyspace. This is needed to
// properly store the lastpk for the source.
sourcePkCols []int

// selectPks is the list of pk columns as they appear in the select clause for the diff.
selectPks []int
Expand Down Expand Up @@ -207,6 +211,7 @@ func (tp *tablePlan) findPKs(dbClient binlogplayer.DBClient, targetSelect *sqlpa
if len(tp.table.PrimaryKeyColumns) == 0 {
return nil
}

var orderby sqlparser.OrderBy
for _, pk := range tp.table.PrimaryKeyColumns {
found := false
Expand Down
11 changes: 0 additions & 11 deletions go/vt/vttablet/tabletmanager/vdiff/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,6 @@ func insertVDiffLog(ctx context.Context, dbClient binlogplayer.DBClient, vdiffID
}
}

// stringListContains reports whether item appears in lst. A nil or empty
// list never contains anything.
func stringListContains(lst []string, item string) bool {
	for i := range lst {
		if lst[i] == item {
			return true
		}
	}
	return false
}

// copyNonKeyRangeExpressions copies all expressions from the input WHERE clause
// to the output WHERE clause except for any in_keyrange() expressions.
func copyNonKeyRangeExpressions(where *sqlparser.Where) *sqlparser.Where {
Expand Down
65 changes: 57 additions & 8 deletions go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ package vdiff

import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"slices"
"strings"
"time"

"golang.org/x/exp/maps"
"google.golang.org/protobuf/encoding/prototext"

"vitess.io/vitess/go/mysql/collations"
Expand Down Expand Up @@ -344,7 +347,7 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl

for _, table := range schm.TableDefinitions {
// if user specified tables explicitly only use those, otherwise diff all tables in workflow
if len(specifiedTables) != 0 && !stringListContains(specifiedTables, table.Name) {
if len(specifiedTables) != 0 && !slices.Contains(specifiedTables, table.Name) {
continue
}
if schema.IsInternalOperationTableName(table.Name) && !schema.IsOnlineDDLTableName(table.Name) {
Expand All @@ -370,15 +373,52 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl
}

td := newTableDiffer(wd, table, sourceQuery)
lastpkpb, err := wd.getTableLastPK(dbClient, table.Name)

lastPK, err := wd.getTableLastPK(dbClient, table.Name)
if err != nil {
return err
}
td.lastPK = lastpkpb
td.lastSourcePK = lastPK["source"]
td.lastTargetPK = lastPK["target"]
wd.tableDiffers[table.Name] = td
if _, err := td.buildTablePlan(dbClient, wd.ct.vde.dbName, wd.collationEnv); err != nil {
return err
}

// We get the PK columns from the source schema as well, as they can
// differ and determine the proper lastPK to use when saving progress.
// We use the first sourceShard as all of them should have the same schema.
sourceShardName := maps.Keys(wd.ct.sources)[0]
sourceShard, err := wd.ct.ts.GetShard(wd.ct.vde.ctx, wd.ct.sourceKeyspace, sourceShardName)
if err != nil {
return err
}
if sourceShard.PrimaryAlias == nil {
return fmt.Errorf("source shard %s has no primary", sourceShardName)
}
sourceTablet, err := wd.ct.ts.GetTablet(wd.ct.vde.ctx, sourceShard.PrimaryAlias)
if err != nil {
return fmt.Errorf("failed to get source shard %s primary", sourceShardName)
}
sourceSchema, err := wd.ct.tmc.GetSchema(wd.ct.vde.ctx, sourceTablet.Tablet, &tabletmanagerdatapb.GetSchemaRequest{
Tables: []string{table.Name},
})
if err != nil {
return err
}
//log.Errorf("DEBUG: sourceTable.PrimaryKeyColumns: %v", sourceSchema.TableDefinitions[0].PrimaryKeyColumns)
sourcePKColumns := make(map[string]struct{}, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns))
td.tablePlan.sourcePkCols = make([]int, 0, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns))
for _, pkc := range sourceSchema.TableDefinitions[0].PrimaryKeyColumns {
sourcePKColumns[pkc] = struct{}{}
}
//log.Errorf("DEBUG: sourcePKColumns: %v", sourcePKColumns)
for i, pkc := range table.PrimaryKeyColumns {
if _, ok := sourcePKColumns[pkc]; ok {
td.tablePlan.sourcePkCols = append(td.tablePlan.sourcePkCols, i)
}
}
//log.Errorf("DEBUG: td.tablePlan.sourcePkCols: %v", td.tablePlan.sourcePkCols)
}
if len(wd.tableDiffers) == 0 {
return fmt.Errorf("no tables found to diff, %s:%s, on tablet %v",
Expand All @@ -388,7 +428,7 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl
}

// getTableLastPK gets the lastPK protobuf messages for a given vdiff table,
// keyed by "source" and "target".
func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableName string) (*querypb.QueryResult, error) {
func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableName string) (map[string]*querypb.QueryResult, error) {
query, err := sqlparser.ParseAndBind(sqlGetVDiffTable,
sqltypes.Int64BindVariable(wd.ct.id),
sqltypes.StringBindVariable(tableName),
Expand All @@ -406,11 +446,20 @@ func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableNa
return nil, err
}
if len(lastpk) != 0 {
var lastpkpb querypb.QueryResult
if err := prototext.Unmarshal(lastpk, &lastpkpb); err != nil {
return nil, err
lastPK := make(map[string]string, 2)
lastPKResults := make(map[string]*querypb.QueryResult, 2)
if err := json.Unmarshal(lastpk, &lastPK); err != nil {
return nil, vterrors.Wrapf(err, "failed to unmarshal lastpk JSON for table %s", tableName)
}
//log.Errorf("DEBUG: getTabletLastPK lastPKBytes: %v", lastPK)
for k, v := range lastPK {
lastPKResults[k] = &querypb.QueryResult{}
if err := prototext.Unmarshal([]byte(v), lastPKResults[k]); err != nil {
return nil, vterrors.Wrapf(err, "failed to unmarshal lastpk QueryResult for table %s", tableName)
}
}
return &lastpkpb, nil
//log.Errorf("DEBUG: getTabletLastPK lastPKRResults: %v", lastPKResults)
return lastPKResults, nil
}
}
return nil, nil
Expand Down

0 comments on commit b97ad0e

Please sign in to comment.