From 7fdbf58b41a3169990cdf03372326781a2fd8cb6 Mon Sep 17 00:00:00 2001 From: Nick Tobey Date: Fri, 19 Jan 2024 09:42:28 -0800 Subject: [PATCH] RangeHeapJoin should consistently sort NULL values before non-NULL values while managing its heap. --- enginetest/join_planning_tests.go | 46 +++++++++++++++++-------------- enginetest/queries/query_plans.go | 2 +- sql/memo/exec_builder.go | 4 +-- sql/rowexec/range_heap_iter.go | 28 ++++++++++++++----- 4 files changed, 50 insertions(+), 30 deletions(-) diff --git a/enginetest/join_planning_tests.go b/enginetest/join_planning_tests.go index 08268516aa..636f374549 100644 --- a/enginetest/join_planning_tests.go +++ b/enginetest/join_planning_tests.go @@ -1060,12 +1060,12 @@ join uv d on d.u = c.x`, }, }, { - name: "primary key range join", + name: "indexed range join", setup: []string{ - "create table vals (val int primary key);", - "create table ranges (min int primary key, max int, unique key(min,max));", - "insert into vals values (0), (1), (2), (3), (4), (5), (6);", - "insert into ranges values (0,2), (1,3), (2,4), (3,5), (4,6);", + "create table vals (val int unique key);", + "create table ranges (min int unique key, max int, unique key(min,max));", + "insert into vals values (null), (0), (1), (2), (3), (4), (5), (6);", + "insert into ranges values (null,1), (0,2), (1,3), (2,4), (3,5), (4,6);", }, tests: []JoinPlanTest{ { @@ -1246,8 +1246,9 @@ join uv d on d.u = c.x`, }, { q: "select * from vals where exists (select * from vals join ranges on val between min and max where min >= 2 and max <= 5)", - types: []plan.JoinType{plan.JoinTypeCross, plan.JoinTypeInner}, + types: []plan.JoinType{plan.JoinTypeCrossHash, plan.JoinTypeInner}, exp: []sql.Row{ + {nil}, {0}, {1}, {2}, @@ -1296,18 +1297,20 @@ join uv d on d.u = c.x`, {6}, }, }, - { - q: "select * from vals where exists (select * from ranges where val between min and max limit 1 offset 1);", - types: []plan.JoinType{plan.JoinTypeSemi}, - exp: []sql.Row{ - {1}, - {2}, - {3}, - {4}, - {5}, - {6}, - }, - }, + /* + Disabled because of https://github.com/dolthub/go-mysql-server/issues/2277 + { + q: "select * from vals where exists (select * from ranges where val between min and max limit 1 offset 1);", + types: []plan.JoinType{plan.JoinTypeSemi}, + exp: []sql.Row{ + {1}, + {2}, + {3}, + {4}, + {5}, + }, + }, + */ { q: "select * from vals where exists (select * from ranges where val between min and max having val > 1);", types: []plan.JoinType{}, @@ -1326,8 +1329,8 @@ join uv d on d.u = c.x`, setup: []string{ "create table vals (val int)", "create table ranges (min int, max int)", - "insert into vals values (0), (1), (2), (3), (4), (5), (6)", - "insert into ranges values (0,2), (1,3), (2,4), (3,5), (4,6)", + "insert into vals values (null), (0), (1), (2), (3), (4), (5), (6)", + "insert into ranges values (null,1), (0,2), (1,3), (2,4), (3,5), (4,6)", }, tests: []JoinPlanTest{ { @@ -1430,6 +1433,7 @@ join uv d on d.u = c.x`, q: "select * from vals left join ranges on val > min and val < max", types: []plan.JoinType{plan.JoinTypeLeftOuterRangeHeap}, exp: []sql.Row{ + {nil, nil, nil}, {0, nil, nil}, {1, 0, 2}, {2, 1, 3}, @@ -1453,6 +1457,7 @@ join uv d on d.u = c.x`, q: "select * from vals left join ranges r1 on val > r1.min and val < r1.max left join ranges r2 on r1.min > r2.min and r1.min < r2.max", types: []plan.JoinType{plan.JoinTypeLeftOuterRangeHeap, plan.JoinTypeLeftOuterRangeHeap}, exp: []sql.Row{ + {nil, nil, nil, nil, nil}, {0, nil, nil, nil, nil}, {1, 0, 2, nil, nil}, {2, 1, 3, 0, 2}, @@ -1494,6 +1499,7 @@ join uv d on d.u = c.x`, q: "select * from vals left join (select * from ranges where 0) as newRanges on val > min and val < max;", types: []plan.JoinType{plan.JoinTypeLeftOuterRangeHeap}, exp: []sql.Row{ + {nil, nil, nil}, {0, nil, nil}, {1, nil, nil}, {2, nil, nil}, diff --git a/enginetest/queries/query_plans.go b/enginetest/queries/query_plans.go index f176130197..ac0dce1755 100644 --- a/enginetest/queries/query_plans.go +++ b/enginetest/queries/query_plans.go @@ -12165,7 +12165,7 @@ order by x, y; " │ │ └─ Table\n" + " │ │ ├─ name: xy\n" + " │ │ └─ columns: [x y]\n" + - " │ └─ Sort(bigtable.n:1 ASC nullsLast)\n" + + " │ └─ Sort(bigtable.n:1 ASC nullsFirst)\n" + " │ └─ ProcessTable\n" + " │ └─ Table\n" + " │ ├─ name: bigtable\n" + diff --git a/sql/memo/exec_builder.go b/sql/memo/exec_builder.go index 06a95a77ff..ba53e74282 100644 --- a/sql/memo/exec_builder.go +++ b/sql/memo/exec_builder.go @@ -100,7 +100,7 @@ func (b *ExecBuilder) buildRangeHeap(sr *RangeHeap, children ...sql.Node) (ret s sf := []sql.SortField{{ Column: sortExpr, Order: sql.Ascending, - NullOrdering: sql.NullsLast, // Due to https://github.com/dolthub/go-mysql-server/issues/1903 + NullOrdering: sql.NullsFirst, }} childNode = plan.NewSort(sf, n) } @@ -135,7 +135,7 @@ func (b *ExecBuilder) buildRangeHeapJoin(j *RangeHeapJoin, children ...sql.Node) sf := []sql.SortField{{ Column: sortExpr, Order: sql.Ascending, - NullOrdering: sql.NullsLast, // Due to https://github.com/dolthub/go-mysql-server/issues/1903 + NullOrdering: sql.NullsFirst, }} left = plan.NewSort(sf, children[0]) } diff --git a/sql/rowexec/range_heap_iter.go b/sql/rowexec/range_heap_iter.go index 283702fffc..08b09f81a0 100644 --- a/sql/rowexec/range_heap_iter.go +++ b/sql/rowexec/range_heap_iter.go @@ -215,9 +215,6 @@ func (iter *rangeHeapJoinIter) Close(ctx *sql.Context) (err error) { return err } -type rangeHeapRowIterProvider struct { -} - func (iter *rangeHeapJoinIter) initializeHeap(ctx *sql.Context, builder sql.NodeExecBuilder, primaryRow sql.Row) (err error) { iter.childRowIter, err = builder.Build(ctx, iter.rangeHeapPlan.Child, primaryRow) if err != nil { @@ -235,11 +232,10 @@ func (iter *rangeHeapJoinIter) initializeHeap(ctx *sql.Context, builder sql.Node } func (iter *rangeHeapJoinIter) getActiveRanges(ctx *sql.Context, _ sql.NodeExecBuilder, row sql.Row) (sql.RowIter, error) { - // Remove rows from the heap if we've advanced beyond their max value. for iter.Len() > 0 { maxValue := iter.Peek() - compareResult, err := iter.rangeHeapPlan.ComparisonType.Compare(row[iter.rangeHeapPlan.ValueColumnIndex], maxValue) + compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, row[iter.rangeHeapPlan.ValueColumnIndex], maxValue) if err != nil { return nil, err } @@ -258,7 +254,7 @@ func (iter *rangeHeapJoinIter) getActiveRanges(ctx *sql.Context, _ sql.NodeExecB // Advance the child iterator until we encounter a row whose min value is beyond the range. for iter.pendingRow != nil { minValue := iter.pendingRow[iter.rangeHeapPlan.MinColumnIndex] - compareResult, err := iter.rangeHeapPlan.ComparisonType.Compare(row[iter.rangeHeapPlan.ValueColumnIndex], minValue) + compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, row[iter.rangeHeapPlan.ValueColumnIndex], minValue) if err != nil { return nil, err } @@ -289,13 +285,31 @@ func (iter *rangeHeapJoinIter) getActiveRanges(ctx *sql.Context, _ sql.NodeExecB return sql.RowsToRowIter(iter.activeRanges...), nil } +// When managing the heap, consider all NULLs to come before any non-NULLS. +// This is consistent with the order received if either child node is an index. +// Note: We could get the same behavior by simply excluding values and ranges containing NULL, +// but this is forward compatible if we ever want to convert joins with null-safe conditions into RangeHeapJoins. +func compareNullsFirst(comparisonType sql.Type, a, b interface{}) (int, error) { + if a == nil { + if b == nil { + return 0, nil + } else { + return -1, nil + } + } + if b == nil { + return 1, nil + } + return comparisonType.Compare(a, b) +} + func (iter rangeHeapJoinIter) Len() int { return len(iter.activeRanges) } func (iter *rangeHeapJoinIter) Less(i, j int) bool { lhs := iter.activeRanges[i][iter.rangeHeapPlan.MaxColumnIndex] rhs := iter.activeRanges[j][iter.rangeHeapPlan.MaxColumnIndex] // compareResult will be 0 if lhs==rhs, -1 if lhs < rhs, and +1 if lhs > rhs. - compareResult, err := iter.rangeHeapPlan.ComparisonType.Compare(lhs, rhs) + compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, lhs, rhs) if iter.err == nil && err != nil { iter.err = err }