diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index 9ee252a170..df17f8ddd7 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -185,216 +185,217 @@ internal object PlanTransform { override fun visitRexOpCase(node: Rex.Op.Case, ctx: Unit) = org.partiql.plan.Rex.Op.Case( branches = node.branches.map { visitRexOpCaseBranch(it, ctx) }, default = visitRex(node.default, ctx) - ) - - override fun visitRexOpCaseBranch(node: Rex.Op.Case.Branch, ctx: Unit) = org.partiql.plan.Rex.Op.Case.Branch( - condition = visitRex(node.condition, ctx), rex = visitRex(node.rex, ctx) - ) - - override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: Unit) = - org.partiql.plan.Rex.Op.Collection(values = node.values.map { visitRex(it, ctx) }) + ) - override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: Unit) = - org.partiql.plan.Rex.Op.Struct(fields = node.fields.map { visitRexOpStructField(it, ctx) }) + override fun visitRexOpCaseBranch(node: Rex.Op.Case.Branch, ctx: Unit) = org.partiql.plan.Rex.Op.Case.Branch( + condition = visitRex(node.condition, ctx), rex = visitRex(node.rex, ctx) + ) - override fun visitRexOpStructField(node: Rex.Op.Struct.Field, ctx: Unit) = org.partiql.plan.Rex.Op.Struct.Field( - k = visitRex(node.k, ctx), - v = visitRex(node.v, ctx), - ) + override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: Unit) = + org.partiql.plan.Rex.Op.Collection(values = node.values.map { visitRex(it, ctx) }) - override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: Unit) = org.partiql.plan.Rex.Op.Pivot( - key = visitRex(node.key, ctx), - value = visitRex(node.value, ctx), - rel = visitRel(node.rel, ctx), - ) + override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: Unit) = + org.partiql.plan.Rex.Op.Struct(fields = node.fields.map { visitRexOpStructField(it, ctx) }) - override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: Unit) = org.partiql.plan.Rex.Op.Subquery( - select = visitRexOpSelect(node.select, ctx), - coercion = when (node.coercion) { - Rex.Op.Subquery.Coercion.SCALAR -> org.partiql.plan.Rex.Op.Subquery.Coercion.SCALAR - Rex.Op.Subquery.Coercion.ROW -> org.partiql.plan.Rex.Op.Subquery.Coercion.ROW - } - ) + override fun visitRexOpStructField(node: Rex.Op.Struct.Field, ctx: Unit) = org.partiql.plan.Rex.Op.Struct.Field( + k = visitRex(node.k, ctx), + v = visitRex(node.v, ctx), + ) - override fun visitRexOpSelect(node: Rex.Op.Select, ctx: Unit) = org.partiql.plan.Rex.Op.Select( - constructor = visitRex(node.constructor, ctx), - rel = visitRel(node.rel, ctx), - ) + override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: Unit) = org.partiql.plan.Rex.Op.Pivot( + key = visitRex(node.key, ctx), + value = visitRex(node.value, ctx), + rel = visitRel(node.rel, ctx), + ) - override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: Unit) = - org.partiql.plan.Rex.Op.TupleUnion(args = node.args.map { visitRex(it, ctx) }) + override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: Unit) = org.partiql.plan.Rex.Op.Subquery( + select = visitRexOpSelect(node.select, ctx), + coercion = when (node.coercion) { + Rex.Op.Subquery.Coercion.SCALAR -> org.partiql.plan.Rex.Op.Subquery.Coercion.SCALAR + Rex.Op.Subquery.Coercion.ROW -> org.partiql.plan.Rex.Op.Subquery.Coercion.ROW + } + ) - override fun visitRexOpErr(node: Rex.Op.Err, ctx: Unit) = org.partiql.plan.Rex.Op.Err(node.message) + override fun visitRexOpSelect(node: Rex.Op.Select, ctx: Unit) = org.partiql.plan.Rex.Op.Select( + constructor = visitRex(node.constructor, ctx), + rel = visitRel(node.rel, ctx), + ) - // RELATION OPERATORS + override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: Unit) = + org.partiql.plan.Rex.Op.TupleUnion(args = node.args.map { visitRex(it, ctx) }) - override fun visitRel(node: Rel, ctx: Unit) = org.partiql.plan.Rel( - type = visitRelType(node.type, ctx), - op = visitRelOp(node.op, ctx), - ) + override fun visitRexOpErr(node: Rex.Op.Err, ctx: Unit) = org.partiql.plan.Rex.Op.Err(node.message) - override fun visitRelType(node: Rel.Type, ctx: Unit) = - org.partiql.plan.Rel.Type( - schema = node.schema.map { visitRelBinding(it, ctx) }, - props = node.props.map { - when (it) { - Rel.Prop.ORDERED -> org.partiql.plan.Rel.Prop.ORDERED - } - }.toSet() + // RELATION OPERATORS + override fun visitRel(node: Rel, ctx: Unit) = org.partiql.plan.Rel( + type = visitRelType(node.type, ctx), + op = visitRelOp(node.op, ctx), ) - override fun visitRelOp(node: Rel.Op, ctx: Unit) = super.visitRelOp(node, ctx) as org.partiql.plan.Rel.Op + override fun visitRelType(node: Rel.Type, ctx: Unit) = + org.partiql.plan.Rel.Type( + schema = node.schema.map { visitRelBinding(it, ctx) }, + props = node.props.map { + when (it) { + Rel.Prop.ORDERED -> org.partiql.plan.Rel.Prop.ORDERED + } + }.toSet() - override fun visitRelOpScan(node: Rel.Op.Scan, ctx: Unit) = org.partiql.plan.Rel.Op.Scan( - rex = visitRex(node.rex, ctx), - ) + ) - override fun visitRelOpScanIndexed(node: Rel.Op.ScanIndexed, ctx: Unit) = org.partiql.plan.Rel.Op.ScanIndexed( - rex = visitRex(node.rex, ctx), - ) + override fun visitRelOp(node: Rel.Op, ctx: Unit) = super.visitRelOp(node, ctx) as org.partiql.plan.Rel.Op - override fun visitRelOpUnpivot(node: Rel.Op.Unpivot, ctx: Unit) = org.partiql.plan.Rel.Op.Unpivot( - rex = visitRex(node.rex, ctx), - ) + override fun visitRelOpScan(node: Rel.Op.Scan, ctx: Unit) = org.partiql.plan.Rel.Op.Scan( + rex = visitRex(node.rex, ctx), + ) - override fun visitRelOpDistinct(node: Rel.Op.Distinct, ctx: Unit) = org.partiql.plan.Rel.Op.Distinct( - input = visitRel(node.input, ctx), - ) + override fun visitRelOpScanIndexed(node: Rel.Op.ScanIndexed, ctx: Unit) = org.partiql.plan.Rel.Op.ScanIndexed( + rex = visitRex(node.rex, ctx), + ) - override fun visitRelOpFilter(node: Rel.Op.Filter, ctx: Unit) = org.partiql.plan.Rel.Op.Filter( - input = visitRel(node.input, ctx), - predicate = visitRex(node.predicate, ctx), - ) + override fun visitRelOpUnpivot(node: Rel.Op.Unpivot, ctx: Unit) = org.partiql.plan.Rel.Op.Unpivot( + rex = visitRex(node.rex, ctx), + ) - override fun visitRelOpSort(node: Rel.Op.Sort, ctx: Unit) = - org.partiql.plan.Rel.Op.Sort( + override fun visitRelOpDistinct(node: Rel.Op.Distinct, ctx: Unit) = org.partiql.plan.Rel.Op.Distinct( input = visitRel(node.input, ctx), - specs = node.specs.map { visitRelOpSortSpec(it, ctx) } ) - override fun visitRelOpSortSpec(node: Rel.Op.Sort.Spec, ctx: Unit) = org.partiql.plan.Rel.Op.Sort.Spec( - rex = visitRex(node.rex, ctx), - order = when (node.order) { - Rel.Op.Sort.Order.ASC_NULLS_LAST -> org.partiql.plan.Rel.Op.Sort.Order.ASC_NULLS_LAST - Rel.Op.Sort.Order.ASC_NULLS_FIRST -> org.partiql.plan.Rel.Op.Sort.Order.ASC_NULLS_FIRST - Rel.Op.Sort.Order.DESC_NULLS_LAST -> org.partiql.plan.Rel.Op.Sort.Order.DESC_NULLS_LAST - Rel.Op.Sort.Order.DESC_NULLS_FIRST -> org.partiql.plan.Rel.Op.Sort.Order.DESC_NULLS_FIRST - } - ) - - override fun visitRelOpUnion(node: Rel.Op.Union, ctx: Unit) = org.partiql.plan.Rel.Op.Union( - lhs = visitRel(node.lhs, ctx), - rhs = visitRel(node.rhs, ctx), - ) + override fun visitRelOpFilter(node: Rel.Op.Filter, ctx: Unit) = org.partiql.plan.Rel.Op.Filter( + input = visitRel(node.input, ctx), + predicate = visitRex(node.predicate, ctx), + ) - override fun visitRelOpIntersect(node: Rel.Op.Intersect, ctx: Unit) = org.partiql.plan.Rel.Op.Intersect( - lhs = visitRel(node.lhs, ctx), - rhs = visitRel(node.rhs, ctx), - ) + override fun visitRelOpSort(node: Rel.Op.Sort, ctx: Unit) = + org.partiql.plan.Rel.Op.Sort( + input = visitRel(node.input, ctx), + specs = node.specs.map { visitRelOpSortSpec(it, ctx) } + ) + + override fun visitRelOpSortSpec(node: Rel.Op.Sort.Spec, ctx: Unit) = org.partiql.plan.Rel.Op.Sort.Spec( + rex = visitRex(node.rex, ctx), + order = when (node.order) { + Rel.Op.Sort.Order.ASC_NULLS_LAST -> org.partiql.plan.Rel.Op.Sort.Order.ASC_NULLS_LAST + Rel.Op.Sort.Order.ASC_NULLS_FIRST -> org.partiql.plan.Rel.Op.Sort.Order.ASC_NULLS_FIRST + Rel.Op.Sort.Order.DESC_NULLS_LAST -> org.partiql.plan.Rel.Op.Sort.Order.DESC_NULLS_LAST + Rel.Op.Sort.Order.DESC_NULLS_FIRST -> org.partiql.plan.Rel.Op.Sort.Order.DESC_NULLS_FIRST + } + ) - override fun visitRelOpExcept(node: Rel.Op.Except, ctx: Unit) = org.partiql.plan.Rel.Op.Except( - lhs = visitRel(node.lhs, ctx), - rhs = visitRel(node.rhs, ctx), - ) + override fun visitRelOpUnion(node: Rel.Op.Union, ctx: Unit) = org.partiql.plan.Rel.Op.Union( + lhs = visitRel(node.lhs, ctx), + rhs = visitRel(node.rhs, ctx), + ) - override fun visitRelOpLimit(node: Rel.Op.Limit, ctx: Unit) = org.partiql.plan.Rel.Op.Limit( - input = visitRel(node.input, ctx), - limit = visitRex(node.limit, ctx), - ) + override fun visitRelOpIntersect(node: Rel.Op.Intersect, ctx: Unit) = org.partiql.plan.Rel.Op.Intersect( + lhs = visitRel(node.lhs, ctx), + rhs = visitRel(node.rhs, ctx), + ) - override fun visitRelOpOffset(node: Rel.Op.Offset, ctx: Unit) = org.partiql.plan.Rel.Op.Offset( - input = visitRel(node.input, ctx), - offset = visitRex(node.offset, ctx), - ) + override fun visitRelOpExcept(node: Rel.Op.Except, ctx: Unit) = org.partiql.plan.Rel.Op.Except( + lhs = visitRel(node.lhs, ctx), + rhs = visitRel(node.rhs, ctx), + ) - override fun visitRelOpProject(node: Rel.Op.Project, ctx: Unit) = org.partiql.plan.Rel.Op.Project( - input = visitRel(node.input, ctx), - projections = node.projections.map { visitRex(it, ctx) }, - ) + override fun visitRelOpLimit(node: Rel.Op.Limit, ctx: Unit) = org.partiql.plan.Rel.Op.Limit( + input = visitRel(node.input, ctx), + limit = visitRex(node.limit, ctx), + ) - override fun visitRelOpJoin(node: Rel.Op.Join, ctx: Unit) = org.partiql.plan.Rel.Op.Join( - lhs = visitRel(node.lhs, ctx), - rhs = visitRel(node.rhs, ctx), - rex = visitRex(node.rex, ctx), - type = when (node.type) { - Rel.Op.Join.Type.INNER -> org.partiql.plan.Rel.Op.Join.Type.INNER - Rel.Op.Join.Type.LEFT -> org.partiql.plan.Rel.Op.Join.Type.LEFT - Rel.Op.Join.Type.RIGHT -> org.partiql.plan.Rel.Op.Join.Type.RIGHT - Rel.Op.Join.Type.FULL -> org.partiql.plan.Rel.Op.Join.Type.FULL - } - ) + override fun visitRelOpOffset(node: Rel.Op.Offset, ctx: Unit) = org.partiql.plan.Rel.Op.Offset( + input = visitRel(node.input, ctx), + offset = visitRex(node.offset, ctx), + ) - override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: Unit) = org.partiql.plan.Rel.Op.Aggregate( - input = visitRel(node.input, ctx), - strategy = when (node.strategy) { - Rel.Op.Aggregate.Strategy.FULL -> org.partiql.plan.Rel.Op.Aggregate.Strategy.FULL - Rel.Op.Aggregate.Strategy.PARTIAL -> org.partiql.plan.Rel.Op.Aggregate.Strategy.PARTIAL - }, - calls = node.calls.map { visitRelOpAggregateCall(it, ctx) }, - groups = node.groups.map { visitRex(it, ctx) }, - ) + override fun visitRelOpProject(node: Rel.Op.Project, ctx: Unit) = org.partiql.plan.Rel.Op.Project( + input = visitRel(node.input, ctx), + projections = node.projections.map { visitRex(it, ctx) }, + ) - override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: Unit) = - super.visitRelOpAggregateCall(node, ctx) as org.partiql.plan.Rel.Op.Aggregate.Call + override fun visitRelOpJoin(node: Rel.Op.Join, ctx: Unit) = org.partiql.plan.Rel.Op.Join( + lhs = visitRel(node.lhs, ctx), + rhs = visitRel(node.rhs, ctx), + rex = visitRex(node.rex, ctx), + type = when (node.type) { + Rel.Op.Join.Type.INNER -> org.partiql.plan.Rel.Op.Join.Type.INNER + Rel.Op.Join.Type.LEFT -> org.partiql.plan.Rel.Op.Join.Type.LEFT + Rel.Op.Join.Type.RIGHT -> org.partiql.plan.Rel.Op.Join.Type.RIGHT + Rel.Op.Join.Type.FULL -> org.partiql.plan.Rel.Op.Join.Type.FULL + } + ) - override fun visitRelOpAggregateCallUnresolved(node: Rel.Op.Aggregate.Call.Unresolved, ctx: Unit): PlanNode { - error("Unresolved aggregate call $node") - } + override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: Unit) = org.partiql.plan.Rel.Op.Aggregate( + input = visitRel(node.input, ctx), + strategy = when (node.strategy) { + Rel.Op.Aggregate.Strategy.FULL -> org.partiql.plan.Rel.Op.Aggregate.Strategy.FULL + Rel.Op.Aggregate.Strategy.PARTIAL -> org.partiql.plan.Rel.Op.Aggregate.Strategy.PARTIAL + }, + calls = node.calls.map { visitRelOpAggregateCall(it, ctx) }, + groups = node.groups.map { visitRex(it, ctx) }, + ) - override fun visitRelOpAggregateCallResolved(node: Rel.Op.Aggregate.Call.Resolved, ctx: Unit): PlanNode { - val agg = node.agg.name - val args = node.args.map { visitRex(it, ctx) } - return org.partiql.plan.relOpAggregateCall(node.agg.name, args) - } + override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: Unit) = + super.visitRelOpAggregateCall(node, ctx) as org.partiql.plan.Rel.Op.Aggregate.Call - override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: Unit) = org.partiql.plan.Rel.Op.Exclude( - input = visitRel(node.input, ctx), - items = node.items.map { visitRelOpExcludeItem(it, ctx) }, - ) + override fun visitRelOpAggregateCallUnresolved(node: Rel.Op.Aggregate.Call.Unresolved, ctx: Unit): PlanNode { + error("Unresolved aggregate call $node") + } - override fun visitRelOpExcludeItem( - node: Rel.Op.Exclude.Item, - ctx: Unit, - ): org.partiql.plan.Rel.Op.Exclude.Item { - val root = when (node.root) { - is Rex.Op.Var.Resolved -> visitRexOpVar(node.root, ctx) as org.partiql.plan.Rex.Op.Var - is Rex.Op.Var.Unresolved -> org.partiql.plan.Rex.Op.Var(-1) // unresolved in `PlanTyper` results in error + override fun visitRelOpAggregateCallResolved(node: Rel.Op.Aggregate.Call.Resolved, ctx: Unit): PlanNode { + val agg = node.agg.name + val args = node.args.map { visitRex(it, ctx) } + return org.partiql.plan.relOpAggregateCall(node.agg.name, args) } - return org.partiql.plan.Rel.Op.Exclude.Item( - root = root, - steps = node.steps.map { visitRelOpExcludeStep(it, ctx) }, + + override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: Unit) = org.partiql.plan.Rel.Op.Exclude( + input = visitRel(node.input, ctx), + items = node.items.map { visitRelOpExcludeItem(it, ctx) }, ) - } - override fun visitRelOpExcludeStep(node: Rel.Op.Exclude.Step, ctx: Unit) = - super.visit(node, ctx) as org.partiql.plan.Rel.Op.Exclude.Step + override fun visitRelOpExcludeItem( + node: Rel.Op.Exclude.Item, + ctx: Unit, + ): org.partiql.plan.Rel.Op.Exclude.Item { + val root = when (node.root) { + is Rex.Op.Var.Resolved -> visitRexOpVar(node.root, ctx) as org.partiql.plan.Rex.Op.Var + is Rex.Op.Var.Unresolved -> org.partiql.plan.Rex.Op.Var(-1) // unresolved in `PlanTyper` results in error + } + return org.partiql.plan.Rel.Op.Exclude.Item( + root = root, + steps = node.steps.map { visitRelOpExcludeStep(it, ctx) }, + ) + } - override fun visitRelOpExcludeStepStructField(node: Rel.Op.Exclude.Step.StructField, ctx: Unit) = - org.partiql.plan.Rel.Op.Exclude.Step.StructField( - symbol = visitIdentifierSymbol(node.symbol, ctx), - ) + override fun visitRelOpExcludeStep(node: Rel.Op.Exclude.Step, ctx: Unit) = + super.visit(node, ctx) as org.partiql.plan.Rel.Op.Exclude.Step - override fun visitRelOpExcludeStepCollIndex(node: Rel.Op.Exclude.Step.CollIndex, ctx: Unit) = - org.partiql.plan.Rel.Op.Exclude.Step.CollIndex( - index = node.index, - ) + override fun visitRelOpExcludeStepStructField(node: Rel.Op.Exclude.Step.StructField, ctx: Unit) = + org.partiql.plan.Rel.Op.Exclude.Step.StructField( + symbol = visitIdentifierSymbol(node.symbol, ctx), + ) - override fun visitRelOpExcludeStepStructWildcard( - node: Rel.Op.Exclude.Step.StructWildcard, - ctx: Unit, - ) = org.partiql.plan.Rel.Op.Exclude.Step.StructWildcard() + override fun visitRelOpExcludeStepCollIndex(node: Rel.Op.Exclude.Step.CollIndex, ctx: Unit) = + org.partiql.plan.Rel.Op.Exclude.Step.CollIndex( + index = node.index, + ) - override fun visitRelOpExcludeStepCollWildcard( - node: Rel.Op.Exclude.Step.CollWildcard, - ctx: Unit, - ) = org.partiql.plan.Rel.Op.Exclude.Step.CollWildcard() + override fun visitRelOpExcludeStepStructWildcard( + node: Rel.Op.Exclude.Step.StructWildcard, + ctx: Unit, + ) = org.partiql.plan.Rel.Op.Exclude.Step.StructWildcard() - override fun visitRelOpErr(node: Rel.Op.Err, ctx: Unit) = org.partiql.plan.Rel.Op.Err(node.message) + override fun visitRelOpExcludeStepCollWildcard( + node: Rel.Op.Exclude.Step.CollWildcard, + ctx: Unit, + ) = org.partiql.plan.Rel.Op.Exclude.Step.CollWildcard() - override fun visitRelBinding(node: Rel.Binding, ctx: Unit) = org.partiql.plan.Rel.Binding( - name = node.name, - type = node.type, - ) + override fun visitRelOpErr(node: Rel.Op.Err, ctx: Unit) = org.partiql.plan.Rel.Op.Err(node.message) + + override fun visitRelBinding(node: Rel.Binding, ctx: Unit) = org.partiql.plan.Rel.Binding( + name = node.name, + type = node.type, + ) + } } -} + \ No newline at end of file