Skip to content

Commit

Permalink
[release-20.0] Reference Table DML Join Fix (#17414) (#17473)
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
Co-authored-by: Harshit Gangal <[email protected]>
  • Loading branch information
vitess-bot[bot] and harshit-gangal authored Jan 13, 2025
1 parent 6b3c47c commit 192fa94
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 22 deletions.
19 changes: 19 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2285,6 +2285,25 @@ func RemoveKeyspace(in SQLNode) {
})
}

// RemoveKeyspaceIgnoreSysSchema removes the Qualifier.Qualifier on all ColNames and Qualifier on all TableNames in the AST
// except for the system schema.
func RemoveKeyspaceIgnoreSysSchema(in SQLNode) {
Rewrite(in, nil, func(cursor *Cursor) bool {
switch expr := cursor.Node().(type) {
case *ColName:
if expr.Qualifier.Qualifier.NotEmpty() && !SystemSchema(expr.Qualifier.Qualifier.String()) {
expr.Qualifier.Qualifier = NewIdentifierCS("")
}
case TableName:
if expr.Qualifier.NotEmpty() && !SystemSchema(expr.Qualifier.String()) {
expr.Qualifier = NewIdentifierCS("")
cursor.Replace(expr)
}
}
return true
})
}

func convertStringToInt(integer string) int {
val, _ := strconv.Atoi(integer)
return val
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func createDeleteWithInputOp(ctx *plancontext.PlanningContext, del *sqlparser.De
}

var delOps []dmlOp
for _, target := range ctx.SemTable.Targets.Constituents() {
for _, target := range ctx.SemTable.DMLTargets.Constituents() {
op := createDeleteOpWithTarget(ctx, target, del.Ignore)
delOps = append(delOps, op)
}
Expand Down Expand Up @@ -336,7 +336,7 @@ func updateQueryGraphWithSource(ctx *plancontext.PlanningContext, input Operator
return op, NoRewrite
}
if len(qg.Tables) > 1 {
panic(vterrors.VT12001("DELETE on reference table with join"))
panic(vterrors.VT12001("DML on reference table with join"))
}
for _, tbl := range qg.Tables {
if tbl.ID != tblID {
Expand Down
16 changes: 10 additions & 6 deletions go/vt/vtgate/planbuilder/operators/join_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
// If they can be merged, a new operator with the merged routing is returned
// If they cannot be merged, nil is returned.
func mergeJoinInputs(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredicates []sqlparser.Expr, m *joinMerger) *Route {
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(lhs, rhs)
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, lhs, rhs)
if lhsRoute == nil {
return nil
}
Expand Down Expand Up @@ -91,13 +91,13 @@ func mergeAnyShardRoutings(ctx *plancontext.PlanningContext, a, b *AnyShardRouti
}
}

func prepareInputRoutes(lhs Operator, rhs Operator) (*Route, *Route, Routing, Routing, routingType, routingType, bool) {
func prepareInputRoutes(ctx *plancontext.PlanningContext, lhs Operator, rhs Operator) (*Route, *Route, Routing, Routing, routingType, routingType, bool) {
lhsRoute, rhsRoute := operatorsToRoutes(lhs, rhs)
if lhsRoute == nil || rhsRoute == nil {
return nil, nil, nil, nil, 0, 0, false
}

lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(lhsRoute, rhsRoute)
lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(ctx, lhsRoute, rhsRoute)

a, b := getRoutingType(routingA), getRoutingType(routingB)
if getTypeName(routingA) < getTypeName(routingB) {
Expand Down Expand Up @@ -155,7 +155,7 @@ func (rt routingType) String() string {

// getRoutesOrAlternates gets the Routings from each Route. If they are from different keyspaces,
// we check if this is a table with alternates in other keyspaces that we can use
func getRoutesOrAlternates(lhsRoute, rhsRoute *Route) (*Route, *Route, Routing, Routing, bool) {
func getRoutesOrAlternates(ctx *plancontext.PlanningContext, lhsRoute, rhsRoute *Route) (*Route, *Route, Routing, Routing, bool) {
routingA := lhsRoute.Routing
routingB := rhsRoute.Routing
sameKeyspace := routingA.Keyspace() == routingB.Keyspace()
Expand All @@ -167,13 +167,17 @@ func getRoutesOrAlternates(lhsRoute, rhsRoute *Route) (*Route, *Route, Routing,
return lhsRoute, rhsRoute, routingA, routingB, sameKeyspace
}

if refA, ok := routingA.(*AnyShardRouting); ok {
// If we have a reference route, we will try to find an alternate route in same keyspace as other routing keyspace.
// If the reference route is part of DML table update target, alternate keyspace route cannot be considered.
if refA, ok := routingA.(*AnyShardRouting); ok &&
!TableID(lhsRoute).IsOverlapping(ctx.SemTable.DMLTargets) {
if altARoute := refA.AlternateInKeyspace(routingB.Keyspace()); altARoute != nil {
return altARoute, rhsRoute, altARoute.Routing, routingB, true
}
}

if refB, ok := routingB.(*AnyShardRouting); ok {
if refB, ok := routingB.(*AnyShardRouting); ok &&
!TableID(rhsRoute).IsOverlapping(ctx.SemTable.DMLTargets) {
if altBRoute := refB.AlternateInKeyspace(routingA.Keyspace()); altBRoute != nil {
return lhsRoute, altBRoute, routingA, altBRoute.Routing, true
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ func mergeSubqueryInputs(ctx *plancontext.PlanningContext, in, out Operator, joi
return nil
}

inRoute, outRoute, inRouting, outRouting, sameKeyspace := getRoutesOrAlternates(inRoute, outRoute)
inRoute, outRoute, inRouting, outRouting, sameKeyspace := getRoutesOrAlternates(ctx, inRoute, outRoute)
inner, outer := getRoutingType(inRouting), getRoutingType(outRouting)

switch {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/union_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func mergeUnionInputs(
lhsExprs, rhsExprs sqlparser.SelectExprs,
distinct bool,
) (Operator, sqlparser.SelectExprs) {
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(lhs, rhs)
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, lhs, rhs)
if lhsRoute == nil {
return nil, nil
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Up
ueMap := prepareUpdateExpressionList(ctx, upd)

var updOps []dmlOp
for _, target := range ctx.SemTable.Targets.Constituents() {
for _, target := range ctx.SemTable.DMLTargets.Constituents() {
op := createUpdateOpWithTarget(ctx, upd, target, ueMap[target])
updOps = append(updOps, op)
}
Expand Down Expand Up @@ -318,7 +318,7 @@ func errIfUpdateNotSupported(ctx *plancontext.PlanningContext, stmt *sqlparser.U
}
}

// Now we check if any of the foreign key columns that are being udpated have dependencies on other updated columns.
// Now we check if any of the foreign key columns that are being updated have dependencies on other updated columns.
// This is unsafe, and we currently don't support this in Vitess.
if err := ctx.SemTable.ErrIfFkDependentColumnUpdated(stmt.Exprs); err != nil {
panic(err)
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/planbuilder/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func (s *planTestSuite) TestPlan() {
s.addPKsProvided(vschemaWrapper.V, "user", []string{"user_extra"}, []string{"id", "user_id"})
s.addPKsProvided(vschemaWrapper.V, "ordering", []string{"order"}, []string{"oid", "region_id"})
s.addPKsProvided(vschemaWrapper.V, "ordering", []string{"order_event"}, []string{"oid", "ename"})
s.addPKsProvided(vschemaWrapper.V, "main", []string{"source_of_ref"}, []string{"id"})

// You will notice that some tests expect user.Id instead of user.id.
// This is because we now pre-create vindex columns in the symbol
Expand Down Expand Up @@ -304,6 +305,7 @@ func (s *planTestSuite) TestOne() {
s.addPKsProvided(lv, "user", []string{"user_extra"}, []string{"id", "user_id"})
s.addPKsProvided(lv, "ordering", []string{"order"}, []string{"oid", "region_id"})
s.addPKsProvided(lv, "ordering", []string{"order_event"}, []string{"oid", "ename"})
s.addPKsProvided(lv, "main", []string{"source_of_ref"}, []string{"id"})
vschema := &vschemawrapper.VSchemaWrapper{
V: lv,
TestBuilder: TestBuilder,
Expand Down Expand Up @@ -684,7 +686,7 @@ func (s *planTestSuite) testFile(filename string, vschema *vschemawrapper.VSchem
if tcase.Skip {
t.Skip(message)
} else {
t.Errorf(message)
t.Error(message)
}
} else if tcase.Skip {
t.Errorf("query is correct even though it is skipped:\n %s", tcase.Query)
Expand Down
140 changes: 140 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/reference_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -771,5 +771,145 @@
"user.user_extra"
]
}
},
{
"comment": "update reference table with join on sharded table",
"query": "update main.source_of_ref as sr join main.rerouted_ref as rr on sr.id = rr.id inner join user.music as m on sr.col = m.col set sr.tt = 5 where m.user_id = 1",
"plan": {
"QueryType": "UPDATE",
"Original": "update main.source_of_ref as sr join main.rerouted_ref as rr on sr.id = rr.id inner join user.music as m on sr.col = m.col set sr.tt = 5 where m.user_id = 1",
"Instructions": {
"OperatorType": "DMLWithInput",
"TargetTabletType": "PRIMARY",
"Offset": [
"0:[0]"
],
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "R:0",
"JoinVars": {
"m_col": 0
},
"TableName": "music_rerouted_ref, source_of_ref",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "EqualUnique",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select m.col from music as m where 1 != 1",
"Query": "select m.col from music as m where m.user_id = 1 lock in share mode",
"Table": "music",
"Values": [
"1"
],
"Vindex": "user_index"
},
{
"OperatorType": "Route",
"Variant": "Unsharded",
"Keyspace": {
"Name": "main",
"Sharded": false
},
"FieldQuery": "select sr.id from source_of_ref as sr, rerouted_ref as rr where 1 != 1",
"Query": "select sr.id from source_of_ref as sr, rerouted_ref as rr where sr.col = :m_col and sr.id = rr.id lock in share mode",
"Table": "rerouted_ref, source_of_ref"
}
]
},
{
"OperatorType": "Update",
"Variant": "Unsharded",
"Keyspace": {
"Name": "main",
"Sharded": false
},
"TargetTabletType": "PRIMARY",
"Query": "update source_of_ref as sr set sr.tt = 5 where sr.id in ::dml_vals",
"Table": "source_of_ref"
}
]
},
"TablesUsed": [
"main.rerouted_ref",
"main.source_of_ref",
"user.music"
]
}
},
{
"comment": "delete from reference table with join on sharded table",
"query": "delete sr from main.source_of_ref as sr join main.rerouted_ref as rr on sr.id = rr.id inner join user.music as m on sr.col = m.col where m.user_id = 1",
"plan": {
"QueryType": "DELETE",
"Original": "delete sr from main.source_of_ref as sr join main.rerouted_ref as rr on sr.id = rr.id inner join user.music as m on sr.col = m.col where m.user_id = 1",
"Instructions": {
"OperatorType": "DMLWithInput",
"TargetTabletType": "PRIMARY",
"Offset": [
"0:[0]"
],
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "R:0",
"JoinVars": {
"m_col": 0
},
"TableName": "music_rerouted_ref, source_of_ref",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "EqualUnique",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select m.col from music as m where 1 != 1",
"Query": "select m.col from music as m where m.user_id = 1",
"Table": "music",
"Values": [
"1"
],
"Vindex": "user_index"
},
{
"OperatorType": "Route",
"Variant": "Unsharded",
"Keyspace": {
"Name": "main",
"Sharded": false
},
"FieldQuery": "select sr.id from source_of_ref as sr, rerouted_ref as rr where 1 != 1",
"Query": "select sr.id from source_of_ref as sr, rerouted_ref as rr where sr.col = :m_col and sr.id = rr.id",
"Table": "rerouted_ref, source_of_ref"
}
]
},
{
"OperatorType": "Delete",
"Variant": "Unsharded",
"Keyspace": {
"Name": "main",
"Sharded": false
},
"TargetTabletType": "PRIMARY",
"Query": "delete from source_of_ref as sr where sr.id in ::dml_vals",
"Table": "source_of_ref"
}
]
},
"TablesUsed": [
"main.rerouted_ref",
"main.source_of_ref",
"user.music"
]
}
}
]
7 changes: 6 additions & 1 deletion go/vt/vtgate/planbuilder/testdata/unsupported_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,12 @@
{
"comment": "reference table delete with join",
"query": "delete r from user u join ref_with_source r on u.col = r.col",
"plan": "VT12001: unsupported: DELETE on reference table with join"
"plan": "VT12001: unsupported: DML on reference table with join"
},
{
"comment": "reference table update with join",
"query": "update user u join ref_with_source r on u.col = r.col set r.col = 5",
"plan": "VT12001: unsupported: DML on reference table with join"
},
{
"comment": "group_concat unsupported when needs full evaluation at vtgate with more than 1 column",
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/semantics/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func (a *analyzer) newSemTable(
Direct: a.binder.direct,
ExprTypes: a.typer.m,
Tables: a.tables.Tables,
Targets: a.binder.targets,
DMLTargets: a.binder.targets,
NotSingleRouteErr: a.projErr,
NotUnshardedErr: a.unshardedErr,
Warning: a.warning,
Expand Down
14 changes: 7 additions & 7 deletions go/vt/vtgate/semantics/semantic_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ type (
// It doesn't recurse inside derived tables to find the original dependencies.
Direct ExprDependencies

// Targets contains the TableSet of each table getting modified by the update/delete statement.
Targets TableSet
// DMLTargets contains the TableSet of each table getting modified by the update/delete statement.
DMLTargets TableSet

// ColumnEqualities is used for transitive closures (e.g., if a == b and b == c, then a == c).
ColumnEqualities map[columnName][]sqlparser.Expr
Expand Down Expand Up @@ -193,15 +193,15 @@ func (st *SemTable) CopyDependencies(from, to sqlparser.Expr) {

// GetChildForeignKeysForTargets gets the child foreign keys as a list for all the target tables.
func (st *SemTable) GetChildForeignKeysForTargets() (fks []vindexes.ChildFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
fks = append(fks, st.childForeignKeysInvolved[ts]...)
}
return fks
}

// GetChildForeignKeysForTableSet gets the child foreign keys as a listfor the TableSet.
func (st *SemTable) GetChildForeignKeysForTableSet(target TableSet) (fks []vindexes.ChildFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
if target.IsSolvedBy(ts) {
fks = append(fks, st.childForeignKeysInvolved[ts]...)
}
Expand Down Expand Up @@ -229,15 +229,15 @@ func (st *SemTable) GetChildForeignKeysList() []vindexes.ChildFKInfo {

// GetParentForeignKeysForTargets gets the parent foreign keys as a list for all the target tables.
func (st *SemTable) GetParentForeignKeysForTargets() (fks []vindexes.ParentFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
fks = append(fks, st.parentForeignKeysInvolved[ts]...)
}
return fks
}

// GetParentForeignKeysForTableSet gets the parent foreign keys as a list for the TableSet.
func (st *SemTable) GetParentForeignKeysForTableSet(target TableSet) (fks []vindexes.ParentFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
if target.IsSolvedBy(ts) {
fks = append(fks, st.parentForeignKeysInvolved[ts]...)
}
Expand Down Expand Up @@ -984,7 +984,7 @@ func (st *SemTable) UpdateChildFKExpr(origUpdExpr *sqlparser.UpdateExpr, newExpr

// GetTargetTableSetForTableName returns the TableSet for the given table name from the target tables.
func (st *SemTable) GetTargetTableSetForTableName(name sqlparser.TableName) (TableSet, error) {
for _, target := range st.Targets.Constituents() {
for _, target := range st.DMLTargets.Constituents() {
tbl, err := st.Tables[target.TableOffset()].Name()
if err != nil {
return "", err
Expand Down

0 comments on commit 192fa94

Please sign in to comment.