From 1ec356814b4ee6f87a0b4676a46a2435af30be78 Mon Sep 17 00:00:00 2001 From: Drew Kimball Date: Thu, 19 Sep 2024 19:49:54 -0600 Subject: [PATCH] opt: add rule to decorrelate unions in EXISTS subqueries This commit adds a new rule `TryDecorrelateUnion`, which matches on a correlated `Union` or `UnionAll` operator (one that has outer columns) in the input of a `ScalarGroupBy`. The `ScalarGroupBy` must have "any-not-null" semantics, meaning it produces an arbitrary non-null value from each input column. If these conditions are satisfied, the `Union` operator is replaced by an `InnerJoin` between two `ScalarGroupBy` operators, one over each input of the `Union`. A `Project` coalesces columns from each side of the join to produce the final aggregated values. This transformation does not itself decorrelate the `Union` operators, but it does make it easier for other rules to do so. Release note: None Epic: None --- pkg/sql/opt/norm/decorrelate_funcs.go | 29 ++ pkg/sql/opt/norm/rules/decorrelate.opt | 64 ++++ pkg/sql/opt/norm/testdata/rules/decorrelate | 345 ++++++++++++++++++++ 3 files changed, 438 insertions(+) diff --git a/pkg/sql/opt/norm/decorrelate_funcs.go b/pkg/sql/opt/norm/decorrelate_funcs.go index 1bd8121f31a6..3e5e7fbd3264 100644 --- a/pkg/sql/opt/norm/decorrelate_funcs.go +++ b/pkg/sql/opt/norm/decorrelate_funcs.go @@ -1537,3 +1537,32 @@ func getSubstituteColsSetOp(set memo.RelExpr, substituteCols opt.ColSet) opt.Col } return newSubstituteCols } + +// MakeCoalesceProjectionsForUnion builds a series of projections that coalesce +// columns from the left and right inputs of a union, projecting the result +// using the union operator's output columns. 
+func (c *CustomFuncs) MakeCoalesceProjectionsForUnion( + setPrivate *memo.SetPrivate, +) memo.ProjectionsExpr { + projections := make(memo.ProjectionsExpr, len(setPrivate.OutCols)) + for i := range setPrivate.OutCols { + projections[i] = c.f.ConstructProjectionsItem( + c.f.ConstructCoalesce(memo.ScalarListExpr{ + c.f.ConstructVariable(setPrivate.LeftCols[i]), + c.f.ConstructVariable(setPrivate.RightCols[i]), + }), + setPrivate.OutCols[i], + ) + } + return projections +} + +// MakeAnyNotNullScalarGroupBy wraps the input expression in a ScalarGroupBy +// that aggregates the input columns with AnyNotNull functions. +func (c *CustomFuncs) MakeAnyNotNullScalarGroupBy(input memo.RelExpr) memo.RelExpr { + return c.f.ConstructScalarGroupBy( + input, + c.MakeAggCols(opt.AnyNotNullAggOp, input.Relational().OutputCols), + memo.EmptyGroupingPrivate, + ) +} diff --git a/pkg/sql/opt/norm/rules/decorrelate.opt b/pkg/sql/opt/norm/rules/decorrelate.opt index 17712868db29..1a185267075a 100644 --- a/pkg/sql/opt/norm/rules/decorrelate.opt +++ b/pkg/sql/opt/norm/rules/decorrelate.opt @@ -406,6 +406,70 @@ (OutputCols2 $left $right) ) +# TryDecorrelateUnion replaces a Union/UnionAll beneath a ScalarGroupBy with a +# cross-join (InnerJoin on True) between two ScalarGroupBy operators. A Project +# operator coalesces columns from each join input to produce the final result. +# This transformation applies when the ScalarGroupBy has only "any-not-null" +# aggregations, which select an arbitrary non-null value from the input column. +# +# Here's a simplified example: +# +# scalar-group-by +# ├── union-all +# │ ├── scan foo +# │ └── scan bar (has-outer-cols) +# └── aggregations +# └── any-not-null +# => +# project +# ├── inner-join (cross) +# │ ├── scalar-group-by +# │ │ └── scan foo +# │ ├── scalar-group-by +# │ │ └── scan bar +# │ └── filters (true) +# └── projections +# └── coalesce +# +# This situation occurs after a correlated EXISTS subquery containing a Union is +# hoisted. 
Note that TryDecorrelateUnion does not itself decorrelate the Union, +# but makes it easier for other rules to do so. +# +# NOTE: the outer Project operator is necessary just in case the ScalarGroupBy +# is synthesizing new columns, despite using any-not-null aggregations. +# NOTE: TryDecorrelateUnion should be ordered before TryDecorrelateScalarGroupBy +# to ensure that Union operators have a chance to be decorrelated. +# +# TODO(drewk): We could extend this rule to apply to other aggregations; for +# example, for a count() we can sum the counts taken on each side of the join. +# TODO(drewk): We could extend this rule to handle other set operations. For +# example, ExceptAll could become an AntiJoin. +[TryDecorrelateUnion, Normalize] +(ScalarGroupBy + $input:(Union | UnionAll $left:* $right:* $unionPrivate:*) & + (HasOuterCols $input) + $aggs:* & (AreAllAnyNotNullAggs $aggs) + $private:* +) +=> +(Project + (Project + (InnerJoin + (MakeAnyNotNullScalarGroupBy $left) + (MakeAnyNotNullScalarGroupBy $right) + [] + (EmptyJoinPrivate) + ) + (MakeCoalesceProjectionsForUnion $unionPrivate) + (MakeEmptyColSet) + ) + (ConvertAnyNotNullAggsToProjections $aggs) + (IntersectionCols + (GroupingOutputCols $private $aggs) + (OutputCols $input) + ) +) + # TryDecorrelateScalarGroupBy "pushes down" a Join into a ScalarGroupBy # operator, in an attempt to keep "digging" down to find and eliminate # unnecessary correlation. The eventual hope is to trigger the DecorrelateJoin diff --git a/pkg/sql/opt/norm/testdata/rules/decorrelate b/pkg/sql/opt/norm/testdata/rules/decorrelate index 4d824e6849a4..b156616e1839 100644 --- a/pkg/sql/opt/norm/testdata/rules/decorrelate +++ b/pkg/sql/opt/norm/testdata/rules/decorrelate @@ -7552,3 +7552,348 @@ semi-join (hash) │ └── a IS DISTINCT FROM CAST(NULL AS INT8) └── filters └── x = a + +# -------------------------------------------------- +# TryDecorrelateUnion +# -------------------------------------------------- + +# Case with UnionAll. 
+norm expect=TryDecorrelateUnion format=(hide-all,show-columns) +SELECT *, CASE WHEN EXISTS (SELECT 1 FROM xy WHERE x = k UNION ALL SELECT 1 FROM uv WHERE v = k) THEN 1 ELSE 0 END FROM a; +---- +project + ├── columns: k:1 i:2 f:3 s:4 j:5 case:20 + ├── group-by (hash) + │ ├── columns: k:1 i:2 f:3 s:4 j:5 "?column?":12 "?column?":17 + │ ├── grouping columns: k:1 + │ ├── left-join (hash) + │ │ ├── columns: k:1 i:2 f:3 s:4 j:5 x:8 "?column?":12 v:14 "?column?":17 + │ │ ├── left-join (hash) + │ │ │ ├── columns: k:1 i:2 f:3 s:4 j:5 x:8 "?column?":12 + │ │ │ ├── scan a + │ │ │ │ └── columns: k:1 i:2 f:3 s:4 j:5 + │ │ │ ├── project + │ │ │ │ ├── columns: "?column?":12 x:8 + │ │ │ │ ├── scan xy + │ │ │ │ │ └── columns: x:8 + │ │ │ │ └── projections + │ │ │ │ └── 1 [as="?column?":12] + │ │ │ └── filters + │ │ │ └── x:8 = k:1 + │ │ ├── project + │ │ │ ├── columns: "?column?":17 v:14 + │ │ │ ├── scan uv + │ │ │ │ └── columns: v:14 + │ │ │ └── projections + │ │ │ └── 1 [as="?column?":17] + │ │ └── filters + │ │ └── v:14 = k:1 + │ └── aggregations + │ ├── any-not-null-agg [as="?column?":17] + │ │ └── "?column?":17 + │ ├── const-agg [as="?column?":12] + │ │ └── "?column?":12 + │ ├── const-agg [as=i:2] + │ │ └── i:2 + │ ├── const-agg [as=f:3] + │ │ └── f:3 + │ ├── const-agg [as=s:4] + │ │ └── s:4 + │ └── const-agg [as=j:5] + │ └── j:5 + └── projections + └── CASE WHEN COALESCE("?column?":12, "?column?":17) IS NOT NULL THEN 1 ELSE 0 END [as=case:20] + +# Case with Union. 
+norm expect=TryDecorrelateUnion format=(hide-all,show-columns) +SELECT *, CASE WHEN EXISTS (SELECT 1 FROM xy WHERE x = k UNION SELECT 1 FROM uv WHERE v = k) THEN 1 ELSE 0 END FROM a; +---- +project + ├── columns: k:1 i:2 f:3 s:4 j:5 case:20 + ├── group-by (hash) + │ ├── columns: k:1 i:2 f:3 s:4 j:5 "?column?":12 "?column?":17 + │ ├── grouping columns: k:1 + │ ├── left-join (hash) + │ │ ├── columns: k:1 i:2 f:3 s:4 j:5 x:8 "?column?":12 v:14 "?column?":17 + │ │ ├── left-join (hash) + │ │ │ ├── columns: k:1 i:2 f:3 s:4 j:5 x:8 "?column?":12 + │ │ │ ├── scan a + │ │ │ │ └── columns: k:1 i:2 f:3 s:4 j:5 + │ │ │ ├── project + │ │ │ │ ├── columns: "?column?":12 x:8 + │ │ │ │ ├── scan xy + │ │ │ │ │ └── columns: x:8 + │ │ │ │ └── projections + │ │ │ │ └── 1 [as="?column?":12] + │ │ │ └── filters + │ │ │ └── x:8 = k:1 + │ │ ├── project + │ │ │ ├── columns: "?column?":17 v:14 + │ │ │ ├── scan uv + │ │ │ │ └── columns: v:14 + │ │ │ └── projections + │ │ │ └── 1 [as="?column?":17] + │ │ └── filters + │ │ └── v:14 = k:1 + │ └── aggregations + │ ├── any-not-null-agg [as="?column?":17] + │ │ └── "?column?":17 + │ ├── const-agg [as="?column?":12] + │ │ └── "?column?":12 + │ ├── const-agg [as=i:2] + │ │ └── i:2 + │ ├── const-agg [as=f:3] + │ │ └── f:3 + │ ├── const-agg [as=s:4] + │ │ └── s:4 + │ └── const-agg [as=j:5] + │ └── j:5 + └── projections + └── CASE WHEN COALESCE("?column?":12, "?column?":17) IS NOT NULL THEN 1 ELSE 0 END [as=case:20] + +# Case with an uncorrelated Union branch. 
+norm expect=TryDecorrelateUnion format=(hide-all,show-columns) +SELECT *, CASE WHEN EXISTS (SELECT 1 FROM xy WHERE x = k UNION ALL SELECT 1 FROM uv WHERE u = 1) THEN 1 ELSE 0 END FROM a; +---- +project + ├── columns: k:1 i:2 f:3 s:4 j:5 case:20 + ├── inner-join-apply + │ ├── columns: k:1 i:2 f:3 s:4 j:5 "?column?":12 "?column?":17 + │ ├── scan a + │ │ └── columns: k:1 i:2 f:3 s:4 j:5 + │ ├── inner-join (cross) + │ │ ├── columns: "?column?":12 "?column?":17 + │ │ ├── scalar-group-by + │ │ │ ├── columns: "?column?":12 + │ │ │ ├── project + │ │ │ │ ├── columns: "?column?":12 + │ │ │ │ ├── select + │ │ │ │ │ ├── columns: x:8 + │ │ │ │ │ ├── scan xy + │ │ │ │ │ │ └── columns: x:8 + │ │ │ │ │ └── filters + │ │ │ │ │ └── x:8 = k:1 + │ │ │ │ └── projections + │ │ │ │ └── 1 [as="?column?":12] + │ │ │ └── aggregations + │ │ │ └── any-not-null-agg [as="?column?":12] + │ │ │ └── "?column?":12 + │ │ ├── scalar-group-by + │ │ │ ├── columns: "?column?":17 + │ │ │ ├── project + │ │ │ │ ├── columns: "?column?":17 + │ │ │ │ ├── select + │ │ │ │ │ ├── columns: u:13 + │ │ │ │ │ ├── scan uv + │ │ │ │ │ │ └── columns: u:13 + │ │ │ │ │ └── filters + │ │ │ │ │ └── u:13 = 1 + │ │ │ │ └── projections + │ │ │ │ └── 1 [as="?column?":17] + │ │ │ └── aggregations + │ │ │ └── any-not-null-agg [as="?column?":17] + │ │ │ └── "?column?":17 + │ │ └── filters (true) + │ └── filters (true) + └── projections + └── CASE WHEN COALESCE("?column?":12, "?column?":17) IS NOT NULL THEN 1 ELSE 0 END [as=case:20] + +# Case with more than one union operator. 
+norm expect=TryDecorrelateUnion format=(hide-all,show-columns) +SELECT *, CASE WHEN EXISTS ( + SELECT 1 FROM xy WHERE x = k + UNION ALL SELECT 1 FROM uv WHERE v = k + UNION ALL SELECT 1 FROM cd WHERE d = k +) THEN 1 ELSE 0 END FROM a; +---- +project + ├── columns: k:1 i:2 f:3 s:4 j:5 case:26 + ├── group-by (hash) + │ ├── columns: k:1 i:2 f:3 s:4 j:5 "?column?":18 "?column?":23 + │ ├── grouping columns: k:1 + │ ├── left-join (hash) + │ │ ├── columns: k:1 i:2 f:3 s:4 j:5 "?column?":18 d:20 "?column?":23 + │ │ ├── project + │ │ │ ├── columns: "?column?":18 k:1 i:2 f:3 s:4 j:5 + │ │ │ ├── group-by (hash) + │ │ │ │ ├── columns: k:1 i:2 f:3 s:4 j:5 "?column?":12 "?column?":17 + │ │ │ │ ├── grouping columns: k:1 + │ │ │ │ ├── left-join (hash) + │ │ │ │ │ ├── columns: k:1 i:2 f:3 s:4 j:5 x:8 "?column?":12 v:14 "?column?":17 + │ │ │ │ │ ├── left-join (hash) + │ │ │ │ │ │ ├── columns: k:1 i:2 f:3 s:4 j:5 x:8 "?column?":12 + │ │ │ │ │ │ ├── scan a + │ │ │ │ │ │ │ └── columns: k:1 i:2 f:3 s:4 j:5 + │ │ │ │ │ │ ├── project + │ │ │ │ │ │ │ ├── columns: "?column?":12 x:8 + │ │ │ │ │ │ │ ├── scan xy + │ │ │ │ │ │ │ │ └── columns: x:8 + │ │ │ │ │ │ │ └── projections + │ │ │ │ │ │ │ └── 1 [as="?column?":12] + │ │ │ │ │ │ └── filters + │ │ │ │ │ │ └── x:8 = k:1 + │ │ │ │ │ ├── project + │ │ │ │ │ │ ├── columns: "?column?":17 v:14 + │ │ │ │ │ │ ├── scan uv + │ │ │ │ │ │ │ └── columns: v:14 + │ │ │ │ │ │ └── projections + │ │ │ │ │ │ └── 1 [as="?column?":17] + │ │ │ │ │ └── filters + │ │ │ │ │ └── v:14 = k:1 + │ │ │ │ └── aggregations + │ │ │ │ ├── any-not-null-agg [as="?column?":17] + │ │ │ │ │ └── "?column?":17 + │ │ │ │ ├── const-agg [as="?column?":12] + │ │ │ │ │ └── "?column?":12 + │ │ │ │ ├── const-agg [as=i:2] + │ │ │ │ │ └── i:2 + │ │ │ │ ├── const-agg [as=f:3] + │ │ │ │ │ └── f:3 + │ │ │ │ ├── const-agg [as=s:4] + │ │ │ │ │ └── s:4 + │ │ │ │ └── const-agg [as=j:5] + │ │ │ │ └── j:5 + │ │ │ └── projections + │ │ │ └── COALESCE("?column?":12, "?column?":17) [as="?column?":18] + 
│ │ ├── project + │ │ │ ├── columns: "?column?":23 d:20 + │ │ │ ├── scan cd + │ │ │ │ └── columns: d:20 + │ │ │ └── projections + │ │ │ └── 1 [as="?column?":23] + │ │ └── filters + │ │ └── d:20 = k:1 + │ └── aggregations + │ ├── any-not-null-agg [as="?column?":23] + │ │ └── "?column?":23 + │ ├── const-agg [as="?column?":18] + │ │ └── "?column?":18 + │ ├── const-agg [as=i:2] + │ │ └── i:2 + │ ├── const-agg [as=f:3] + │ │ └── f:3 + │ ├── const-agg [as=s:4] + │ │ └── s:4 + │ └── const-agg [as=j:5] + │ └── j:5 + └── projections + └── CASE WHEN COALESCE("?column?":18, "?column?":23) IS NOT NULL THEN 1 ELSE 0 END [as=case:26] + +# No-op because there is no Union. +norm expect-not=TryDecorrelateUnion format=(hide-all,show-columns) +SELECT *, CASE WHEN EXISTS (SELECT 1 FROM xy WHERE x = k) THEN 1 ELSE 0 END FROM a; +---- +project + ├── columns: k:1 i:2 f:3 s:4 j:5 case:14 + ├── left-join (hash) + │ ├── columns: k:1 i:2 f:3 s:4 j:5 x:8 + │ ├── scan a + │ │ └── columns: k:1 i:2 f:3 s:4 j:5 + │ ├── scan xy + │ │ └── columns: x:8 + │ └── filters + │ └── x:8 = k:1 + └── projections + └── CASE WHEN x:8 IS NOT NULL THEN 1 ELSE 0 END [as=case:14] + +# No-op because there's an Intersect instead of a Union. 
+norm expect-not=TryDecorrelateUnion format=(hide-all,show-columns) +SELECT *, CASE WHEN EXISTS (SELECT 1 FROM xy WHERE x = k INTERSECT SELECT 1 FROM uv WHERE v = k) THEN 1 ELSE 0 END FROM a; +---- +project + ├── columns: k:1 i:2 f:3 s:4 j:5 case:19 + ├── left-join-apply + │ ├── columns: k:1 i:2 f:3 s:4 j:5 "?column?":12 + │ ├── scan a + │ │ └── columns: k:1 i:2 f:3 s:4 j:5 + │ ├── intersect-all + │ │ ├── columns: "?column?":12 + │ │ ├── left columns: "?column?":12 + │ │ ├── right columns: "?column?":17 + │ │ ├── project + │ │ │ ├── columns: "?column?":12 + │ │ │ ├── select + │ │ │ │ ├── columns: x:8 + │ │ │ │ ├── scan xy + │ │ │ │ │ └── columns: x:8 + │ │ │ │ └── filters + │ │ │ │ └── x:8 = k:1 + │ │ │ └── projections + │ │ │ └── 1 [as="?column?":12] + │ │ └── project + │ │ ├── columns: "?column?":17 + │ │ ├── select + │ │ │ ├── columns: v:14 + │ │ │ ├── scan uv + │ │ │ │ └── columns: v:14 + │ │ │ └── filters + │ │ │ └── v:14 = k:1 + │ │ └── projections + │ │ └── 1 [as="?column?":17] + │ └── filters (true) + └── projections + └── CASE WHEN "?column?":12 IS NOT NULL THEN 1 ELSE 0 END [as=case:19] + +# No-op case because one of the aggregations isn't any-not-null. 
+norm expect-not=TryDecorrelateUnion format=(hide-all,show-columns) +SELECT * FROM a INNER JOIN LATERAL (SELECT sum(x) FROM (SELECT x FROM xy WHERE x = k UNION ALL SELECT v FROM uv WHERE v = k)) ON True; +---- +group-by (hash) + ├── columns: k:1 i:2 f:3 s:4 j:5 sum:17 + ├── grouping columns: k:1 + ├── left-join-apply + │ ├── columns: k:1 i:2 f:3 s:4 j:5 x:16 + │ ├── scan a + │ │ └── columns: k:1 i:2 f:3 s:4 j:5 + │ ├── union-all + │ │ ├── columns: x:16 + │ │ ├── left columns: xy.x:8 + │ │ ├── right columns: v:13 + │ │ ├── select + │ │ │ ├── columns: xy.x:8 + │ │ │ ├── scan xy + │ │ │ │ └── columns: xy.x:8 + │ │ │ └── filters + │ │ │ └── xy.x:8 = k:1 + │ │ └── select + │ │ ├── columns: v:13 + │ │ ├── scan uv + │ │ │ └── columns: v:13 + │ │ └── filters + │ │ └── v:13 = k:1 + │ └── filters (true) + └── aggregations + ├── sum [as=sum:17] + │ └── x:16 + ├── const-agg [as=i:2] + │ └── i:2 + ├── const-agg [as=f:3] + │ └── f:3 + ├── const-agg [as=s:4] + │ └── s:4 + └── const-agg [as=j:5] + └── j:5 + +# No-op case because the Union isn't correlated. +norm expect-not=TryDecorrelateUnion format=(hide-all,show-columns) +SELECT * FROM a INNER JOIN LATERAL (SELECT 1 FROM xy UNION ALL SELECT 1 FROM uv) ON True; +---- +inner-join (cross) + ├── columns: k:1 i:2 f:3 s:4 j:5 "?column?":18 + ├── scan a + │ └── columns: k:1 i:2 f:3 s:4 j:5 + ├── union-all + │ ├── columns: "?column?":18 + │ ├── left columns: "?column?":12 + │ ├── right columns: "?column?":17 + │ ├── project + │ │ ├── columns: "?column?":12 + │ │ ├── scan xy + │ │ └── projections + │ │ └── 1 [as="?column?":12] + │ └── project + │ ├── columns: "?column?":17 + │ ├── scan uv + │ └── projections + │ └── 1 [as="?column?":17] + └── filters (true)