Skip to content

Commit

Permalink
Code improvements
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Lord <[email protected]>
  • Loading branch information
mattlord committed Jan 9, 2025
1 parent c9f57f7 commit 69947cc
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 73 deletions.
92 changes: 57 additions & 35 deletions go/vt/vttablet/tabletmanager/vdiff/table_differ.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"sync"
"time"

"golang.org/x/exp/maps"
"google.golang.org/protobuf/encoding/prototext"

"vitess.io/vitess/go/mysql/collations"
Expand Down Expand Up @@ -194,7 +195,7 @@ func (td *tableDiffer) stopTargetVReplicationStreams(ctx context.Context, dbClie
return fmt.Errorf("stream %d has not started on tablet %v",
id, td.wd.ct.vde.thisTablet.Alias)
}
sourceBytes, err := row["source"].ToBytes()
sourceBytes, err := row[source].ToBytes()
if err != nil {
return err
}
Expand Down Expand Up @@ -520,8 +521,8 @@ func (td *tableDiffer) diff(ctx context.Context, coreOpts *tabletmanagerdatapb.V
}
dr.TableName = td.table.Name

sourceExecutor := newPrimitiveExecutor(ctx, td.sourcePrimitive, "source")
targetExecutor := newPrimitiveExecutor(ctx, td.targetPrimitive, "target")
sourceExecutor := newPrimitiveExecutor(ctx, td.sourcePrimitive, source)
targetExecutor := newPrimitiveExecutor(ctx, td.targetPrimitive, target)
var sourceRow, lastProcessedRow, targetRow []sqltypes.Value
advanceSource := true
advanceTarget := true
Expand Down Expand Up @@ -736,17 +737,21 @@ func (td *tableDiffer) updateTableProgress(dbClient binlogplayer.DBClient, dr *D
} else {
var lastSourcePK, lastTargetPK []byte
lastPK := make(map[string]string, 2)
if lastRow != nil {
lastSourcePK, err = td.lastSourcePKFromRow(lastRow)
if len(lastRow) != 0 {
lastTargetPK, err = td.lastPKFromRow(lastRow, td.tablePlan.pkCols)
if err != nil {
return err
}
lastPK["source"] = string(lastSourcePK)
lastTargetPK, err = td.lastTargetPKFromRow(lastRow)
if err != nil {
return err
lastPK[target] = string(lastTargetPK)
if len(td.tablePlan.sourcePkCols) == len(td.tablePlan.pkCols) {
lastPK[source] = string(lastTargetPK)
} else {
lastSourcePK, err = td.lastPKFromRow(lastRow, td.tablePlan.sourcePkCols)
if err != nil {
return err
}
lastPK[source] = string(lastSourcePK)
}
lastPK["target"] = string(lastTargetPK)
}
if td.wd.opts.CoreOptions.MaxDiffSeconds > 0 {
// Update the in-memory lastPK as well so that we can restart the table
Expand All @@ -761,12 +766,10 @@ func (td *tableDiffer) updateTableProgress(dbClient binlogplayer.DBClient, dr *D
}
td.lastTargetPK = lastTargetPKPB
}
//log.Errorf("DEBUG: updateTableProgress lastPK map: %v", lastPK)
lastPKJS, err := json.Marshal(lastPK)
if err != nil {
return err
}
//log.Errorf("DEBUG: updateTableProgress lastPK JSON: %v", lastPKJS)
query, err = sqlparser.ParseAndBind(sqlUpdateTableProgress,
sqltypes.Int64BindVariable(dr.ProcessedRows),
sqltypes.StringBindVariable(string(lastPKJS)),
Expand Down Expand Up @@ -845,31 +848,11 @@ func updateTableMismatch(dbClient binlogplayer.DBClient, vdiffID int64, table st
return nil
}

func (td *tableDiffer) lastTargetPKFromRow(row []sqltypes.Value) ([]byte, error) {
pkColCnt := len(td.tablePlan.pkCols)
pkFields := make([]*querypb.Field, pkColCnt)
pkVals := make([]sqltypes.Value, pkColCnt)
for i, colIndex := range td.tablePlan.pkCols {
pkFields[i] = td.tablePlan.table.Fields[colIndex]
pkVals[i] = row[colIndex]
}
buf, err := prototext.Marshal(&querypb.QueryResult{
Fields: pkFields,
Rows: []*querypb.Row{sqltypes.RowToProto3(pkVals)},
})
return buf, err
}

func (td *tableDiffer) lastSourcePKFromRow(row []sqltypes.Value) ([]byte, error) {
if len(td.tablePlan.sourcePkCols) == 0 {
// If there are no PKs on the source then we use
// the same PK[E] columns as the target.
td.tablePlan.sourcePkCols = td.tablePlan.pkCols
}
pkColCnt := len(td.tablePlan.sourcePkCols)
func (td *tableDiffer) lastPKFromRow(row []sqltypes.Value, pkCols []int) ([]byte, error) {
pkColCnt := len(pkCols)
pkFields := make([]*querypb.Field, pkColCnt)
pkVals := make([]sqltypes.Value, pkColCnt)
for i, colIndex := range td.tablePlan.sourcePkCols {
for i, colIndex := range pkCols {
pkFields[i] = td.tablePlan.table.Fields[colIndex]
pkVals[i] = row[colIndex]
}
Expand Down Expand Up @@ -926,6 +909,45 @@ func (td *tableDiffer) adjustForSourceTimeZone(targetSelectExprs sqlparser.Selec
return targetSelectExprs
}

// getSourcePKCols populates the sourcePkCols field in the tablePlan.
// We need this information in order to save the lastpk value for the
// source as the PK columns may differ between the source and target.
func (td *tableDiffer) getSourcePKCols() error {
sourceShardName := maps.Keys(td.wd.ct.sources)[0]
sourceTS, err := td.wd.getSourceTopoServer()
if err != nil {
return vterrors.Wrap(err, "failed to get source topo server")
}
sourceShard, err := sourceTS.GetShard(td.wd.ct.vde.ctx, td.wd.ct.sourceKeyspace, sourceShardName)
if err != nil {
return err
}
if sourceShard.PrimaryAlias == nil {
return fmt.Errorf("source shard %s has no primary", sourceShardName)
}
sourceTablet, err := sourceTS.GetTablet(td.wd.ct.vde.ctx, sourceShard.PrimaryAlias)
if err != nil {
return fmt.Errorf("failed to get source shard %s primary", sourceShardName)
}
sourceSchema, err := td.wd.ct.tmc.GetSchema(td.wd.ct.vde.ctx, sourceTablet.Tablet, &tabletmanagerdatapb.GetSchemaRequest{
Tables: []string{td.table.Name},
})
if err != nil {
return err
}
sourcePKColumns := make(map[string]struct{}, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns))
td.tablePlan.sourcePkCols = make([]int, 0, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns))
for _, pkc := range sourceSchema.TableDefinitions[0].PrimaryKeyColumns {
sourcePKColumns[pkc] = struct{}{}
}
for i, pkc := range td.table.PrimaryKeyColumns {
if _, ok := sourcePKColumns[pkc]; ok {
td.tablePlan.sourcePkCols = append(td.tablePlan.sourcePkCols, i)
}
}
return nil
}

func getColumnNameForSelectExpr(selectExpression sqlparser.SelectExpr) (string, error) {
aliasedExpr := selectExpression.(*sqlparser.AliasedExpr)
expr := aliasedExpr.Expr
Expand Down
46 changes: 8 additions & 38 deletions go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"strings"
"time"

"golang.org/x/exp/maps"
"google.golang.org/protobuf/encoding/prototext"

"vitess.io/vitess/go/mysql/collations"
Expand All @@ -49,6 +48,11 @@ import (
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
)

const (
source = "source"
target = "target"
)

// workflowDiffer has metadata and state for the vdiff of a single workflow on this tablet
// only one vdiff can be running for a workflow at any time.
type workflowDiffer struct {
Expand Down Expand Up @@ -379,8 +383,8 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl
if err != nil {
return err
}
td.lastSourcePK = lastPK["source"]
td.lastTargetPK = lastPK["target"]
td.lastSourcePK = lastPK[source]
td.lastTargetPK = lastPK[target]
wd.tableDiffers[table.Name] = td
if _, err := td.buildTablePlan(dbClient, wd.ct.vde.dbName, wd.collationEnv); err != nil {
return err
Expand All @@ -389,41 +393,9 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl
// We get the PK columns from the source schema as well, as they can
// differ and determine the proper lastPK to use when saving progress.
// We use the first sourceShard as all of them should have the same schema.
sourceShardName := maps.Keys(wd.ct.sources)[0]
sourceTS, err := wd.getSourceTopoServer()
if err != nil {
return vterrors.Wrap(err, "failed to get source topo server")
}
sourceShard, err := sourceTS.GetShard(wd.ct.vde.ctx, wd.ct.sourceKeyspace, sourceShardName)
if err != nil {
return err
}
if sourceShard.PrimaryAlias == nil {
return fmt.Errorf("source shard %s has no primary", sourceShardName)
}
sourceTablet, err := sourceTS.GetTablet(wd.ct.vde.ctx, sourceShard.PrimaryAlias)
if err != nil {
return fmt.Errorf("failed to get source shard %s primary", sourceShardName)
}
sourceSchema, err := wd.ct.tmc.GetSchema(wd.ct.vde.ctx, sourceTablet.Tablet, &tabletmanagerdatapb.GetSchemaRequest{
Tables: []string{table.Name},
})
if err != nil {
if err := td.getSourcePKCols(); err != nil {
return err
}
//log.Errorf("DEBUG: sourceTable.PrimaryKeyColumns: %v", sourceSchema.TableDefinitions[0].PrimaryKeyColumns)
sourcePKColumns := make(map[string]struct{}, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns))
td.tablePlan.sourcePkCols = make([]int, 0, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns))
for _, pkc := range sourceSchema.TableDefinitions[0].PrimaryKeyColumns {
sourcePKColumns[pkc] = struct{}{}
}
//log.Errorf("DEBUG: sourcePKColumns: %v", sourcePKColumns)
for i, pkc := range table.PrimaryKeyColumns {
if _, ok := sourcePKColumns[pkc]; ok {
td.tablePlan.sourcePkCols = append(td.tablePlan.sourcePkCols, i)
}
}
//log.Errorf("DEBUG: td.tablePlan.sourcePkCols: %v", td.tablePlan.sourcePkCols)
}
if len(wd.tableDiffers) == 0 {
return fmt.Errorf("no tables found to diff, %s:%s, on tablet %v",
Expand Down Expand Up @@ -456,14 +428,12 @@ func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableNa
if err := json.Unmarshal(lastpk, &lastPK); err != nil {
return nil, vterrors.Wrapf(err, "failed to unmarshal lastpk JSON for table %s", tableName)
}
//log.Errorf("DEBUG: getTabletLastPK lastPKBytes: %v", lastPK)
for k, v := range lastPK {
lastPKResults[k] = &querypb.QueryResult{}
if err := prototext.Unmarshal([]byte(v), lastPKResults[k]); err != nil {
return nil, vterrors.Wrapf(err, "failed to unmarshal lastpk QueryResult for table %s", tableName)
}
}
//log.Errorf("DEBUG: getTabletLastPK lastPKRResults: %v", lastPKResults)
return lastPKResults, nil
}
}
Expand Down

0 comments on commit 69947cc

Please sign in to comment.