Skip to content

Commit

Permalink
Adds support for ROW and SCALAR subquery coercion
Browse files Browse the repository at this point in the history
  • Loading branch information
johnedquinn committed Jan 29, 2024
1 parent 101c19b commit cda3869
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ public fun Statement.normalize(): Statement {
// could be a fold, but this is nice for setting breakpoints
var ast = this
ast = NormalizeFromSource.apply(ast)
ast = NormalizeSelect.apply(ast)
ast = NormalizeGroupBy.apply(ast)
return ast
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.partiql.eval.internal.operator.rex.ExprPathSymbol
import org.partiql.eval.internal.operator.rex.ExprPivot
import org.partiql.eval.internal.operator.rex.ExprSelect
import org.partiql.eval.internal.operator.rex.ExprStruct
import org.partiql.eval.internal.operator.rex.ExprSubquery
import org.partiql.eval.internal.operator.rex.ExprTupleUnion
import org.partiql.eval.internal.operator.rex.ExprVar
import org.partiql.plan.PartiQLPlan
Expand Down Expand Up @@ -91,12 +92,20 @@ internal class Compiler @OptIn(PartiQLFunctionExperimental::class) constructor(
return ExprStruct(fields)
}

override fun visitRexOpSelect(node: Rex.Op.Select, ctx: Unit): Operator {
override fun visitRexOpSelect(node: Rex.Op.Select, ctx: Unit): Operator.Expr {
val rel = visitRel(node.rel, ctx)
val constructor = visitRex(node.constructor, ctx)
return ExprSelect(rel, constructor)
}

override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: Unit): Operator {
val select = visitRexOpSelect(node.select, ctx)
return when (node.coercion) {
Rex.Op.Subquery.Coercion.SCALAR -> ExprSubquery.Scalar(select)
Rex.Op.Subquery.Coercion.ROW -> ExprSubquery.Row(select)
}
}

override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: Unit): Operator {
val rel = visitRel(node.rel, ctx)
val key = visitRex(node.key, ctx)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package org.partiql.eval.internal.operator.rex

import org.partiql.errors.TypeCheckException
import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.value.CollectionValue
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.StructValue
import org.partiql.value.check
import org.partiql.value.listValue

internal abstract class ExprSubquery : Operator.Expr {

abstract val subquery: Operator.Expr

internal class Row(
override val subquery: Operator.Expr
) : ExprSubquery() {
@PartiQLValueExperimental
override fun eval(record: Record): PartiQLValue {
val values = getFirstAndOnlyTupleValues(record)
return listValue(values.asSequence().toList())
}
}

internal class Scalar(
override val subquery: Operator.Expr
) : ExprSubquery() {
@PartiQLValueExperimental
override fun eval(record: Record): PartiQLValue {
val values = getFirstAndOnlyTupleValues(record)
if (values.hasNext().not()) {
throw TypeCheckException()
}
val singleValue = values.next()
if (values.hasNext()) {
throw TypeCheckException()
}
return singleValue
}
}

/**
* Procedure is as follows:
* 1. Asserts that the [subquery] returns a collection containing a single value. Throws a [TypeCheckException] if not.
* 2. Gets the first value from [subquery].
* 3. Asserts that the first value is a TUPLE ([StructValue]). Throws a [TypeCheckException] if not.
* 4. Returns an [Iterator] of the values contained within the [StructValue].
*/
@OptIn(PartiQLValueExperimental::class)
fun getFirstAndOnlyTupleValues(record: Record): Iterator<PartiQLValue> {
val result = subquery.eval(record)
if (result !is CollectionValue<*>) {
throw TypeCheckException()
}
val resultIterator = result.iterator()
if (resultIterator.hasNext().not()) {
throw TypeCheckException()
}
val tuple = resultIterator.next().check<StructValue<*>>()
if (resultIterator.hasNext()) {
throw TypeCheckException()
}
return tuple.values.iterator()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,57 @@ class PartiQLEngineDefaultTest {
`null.bool` IS NULL
""".trimIndent(),
expected = boolValue(true)
),
SuccessTestCase(
input = """
1 + (SELECT t.a FROM << { 'a': 3 } >> AS t)
""".trimIndent(),
expected = int32Value(4)
),
// SELECT * without nested coercion
SuccessTestCase(
input = """
SELECT *
FROM (
SELECT t.a AS "first", t.b AS "second"
FROM << { 'a': 3, 'b': 5 } >> AS t
);
""".trimIndent(),
expected = bagValue(
structValue(
"first" to int32Value(3),
"second" to int32Value(5)
)
)
),
// SELECT list without nested coercion
SuccessTestCase(
input = """
SELECT "first", "second"
FROM (
SELECT t.a AS "first", t.b AS "second"
FROM << { 'a': 3, 'b': 5 } >> AS t
);
""".trimIndent(),
expected = bagValue(
structValue(
"first" to int32Value(3),
"second" to int32Value(5)
)
)
),
// SELECT value without nested coercion
SuccessTestCase(
input = """
SELECT VALUE "first"
FROM (
SELECT t.a AS "first", t.b AS "second"
FROM << { 'a': 3, 'b': 5 } >> AS t
);
""".trimIndent(),
expected = bagValue(
int32Value(3),
)
)
)
}
Expand Down Expand Up @@ -326,4 +377,25 @@ class PartiQLEngineDefaultTest {
input = "SELECT DISTINCT VALUE t * 100 FROM <<0, 1, 2.0, 3.0>> AS t;",
expected = bagValue(int32Value(0), int32Value(100), int32Value(200), int32Value(300))
).assert()

@Test
@Disabled("Support for ORDER BY needs to be added for this to pass.")
// PartiQL Specification says that SQL's SELECT is coerced, but SELECT VALUE is not.
fun selectValueNoCoercion() =
SuccessTestCase(
input = """
(4, 5) < (SELECT VALUE t.a FROM << { 'a': 3 }, { 'a': 4 } >> AS t ORDER BY t.a)
""".trimIndent(),
expected = boolValue(false)
).assert()

@Test
@Disabled("This is appropriately coerced, but this test is failing because LT currently doesn't support LISTS.")
fun rowCoercion() =
SuccessTestCase(
input = """
(4, 5) < (SELECT t.a, t.a FROM << { 'a': 3 } >> AS t)
""".trimIndent(),
expected = boolValue(false)
).assert()
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
* language governing permissions and limitations under the License.
*/

package org.partiql.ast.normalize
package org.partiql.planner.internal.transforms

import org.partiql.ast.AstNode
import org.partiql.ast.Expr
import org.partiql.ast.From
import org.partiql.ast.Identifier
import org.partiql.ast.Select
import org.partiql.ast.Statement
import org.partiql.ast.exprCall
import org.partiql.ast.exprCase
import org.partiql.ast.exprCaseBranch
Expand Down Expand Up @@ -81,13 +79,26 @@ import org.partiql.value.stringValue
* } FROM A AS x
* ```
*
* TODO: LET
* NOTE: This does NOT transform subqueries. It operates directly on an [Expr.SFW] -- and that is it. Therefore:
* ```
* SELECT
* (SELECT 1 FROM T AS "T")
* FROM R AS "R"
* ```
* will be transformed to:
* ```
* SELECT VALUE {
* '_1': (SELECT 1 FROM T AS "T") -- notice that SELECT 1 didn't get transformed.
* } FROM R AS "R"
* ```
*
* Requires [NormalizeFromSource].
*/
internal object NormalizeSelect : AstPass {
internal object NormalizeSelect {

override fun apply(statement: Statement): Statement = Visitor.visitStatement(statement, newCtx()) as Statement
internal fun normalize(node: Expr.SFW): Expr.SFW {
return Visitor.visitSFW(node, newCtx())
}

/**
* Closure for incrementing a derived binding counter
Expand Down Expand Up @@ -124,15 +135,19 @@ internal object NormalizeSelect : AstPass {
*/
private val col = { index: Int -> "_${index + 1}" }

override fun visitExprSFW(node: Expr.SFW, ctx: () -> Int): Expr.SFW {
internal fun visitSFW(node: Expr.SFW, ctx: () -> Int): Expr.SFW {
val sfw = super.visitExprSFW(node, ctx) as Expr.SFW
return when (val select = sfw.select) {
is Select.Star -> sfw.copy(select = visitSelectAll(select, sfw.from))
else -> sfw
}
}

override fun visitSelectProject(node: Select.Project, ctx: () -> Int): AstNode {
override fun visitExprSFW(node: Expr.SFW, ctx: () -> Int): Expr.SFW {
return node
}

override fun visitSelectProject(node: Select.Project, ctx: () -> Int): Select.Value {

// Visit items, adding a binder if necessary
var diff = false
Expand Down Expand Up @@ -200,7 +215,7 @@ internal object NormalizeSelect : AstPass {
)
}

private fun visitSelectProjectWithProjectAll(node: Select.Project): AstNode {
private fun visitSelectProjectWithProjectAll(node: Select.Project): Select.Value {
val tupleUnionArgs = node.items.mapIndexed { index, item ->
when (item) {
is Select.Project.Item.All -> buildCaseWhenStruct(item.expr, index)
Expand All @@ -221,7 +236,7 @@ internal object NormalizeSelect : AstPass {
}

@OptIn(PartiQLValueExperimental::class)
private fun visitSelectProjectWithoutProjectAll(node: Select.Project): AstNode {
private fun visitSelectProjectWithoutProjectAll(node: Select.Project): Select.Value {
val structFields = node.items.map { item ->
val itemExpr = item as? Select.Project.Item.Expression ?: error("Expected the projection to be an expression.")
exprStructField(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ internal object RelConverter {
* Here we convert an SFW to composed [Rel]s, then apply the appropriate relation-value projection to get a [Rex].
*/
internal fun apply(sfw: Expr.SFW, env: Env): Rex {
val rel = sfw.accept(ToRel(env), nil)
val rex = when (val projection = sfw.select) {
val normalizedSfw = NormalizeSelect.normalize(sfw)
val rel = normalizedSfw.accept(ToRel(env), nil)
val rex = when (val projection = normalizedSfw.select) {
// PIVOT ... FROM
is Select.Pivot -> {
val key = projection.key.toRex(env)
Expand Down Expand Up @@ -149,15 +150,11 @@ internal object RelConverter {
rel = convertExclude(rel, sel.exclude)
// append SQL projection if present
rel = when (val projection = sel.select) {
is Select.Project -> {
val project = visitSelectProject(projection, rel)
visitSetQuantifier(projection.setq, project)
}
is Select.Value -> {
val project = visitSelectValue(projection, rel)
visitSetQuantifier(projection.setq, project)
}
is Select.Star -> error("AST not normalized, found project star")
is Select.Star, is Select.Project -> error("AST not normalized, found ${projection.javaClass.simpleName}")
is Select.Pivot -> rel // Skip PIVOT
}
return rel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.partiql.planner.internal.transforms
import org.partiql.ast.AstNode
import org.partiql.ast.DatetimeField
import org.partiql.ast.Expr
import org.partiql.ast.Select
import org.partiql.ast.Type
import org.partiql.ast.visitor.AstBaseVisitor
import org.partiql.planner.internal.Env
Expand Down Expand Up @@ -106,9 +107,9 @@ internal object RexConverter {
*/
private fun visitExprCoerce(node: Expr, ctx: Env, coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR): Rex {
val rex = super.visitExpr(node, ctx)
return when (rex.op is Rex.Op.Select) {
return when (isSqlSelect(node)) {
true -> rex(StaticType.ANY, rexOpSubquery(rex.op as Rex.Op.Select, coercion))
else -> rex
false -> rex
}
}

Expand Down Expand Up @@ -137,26 +138,57 @@ internal object RexConverter {

override fun visitExprBinary(node: Expr.Binary, context: Env): Rex {
val type = (StaticType.ANY)
// Args
val lhs = visitExprCoerce(node.lhs, context)
val rhs = visitExprCoerce(node.rhs, context)
val args = listOf(lhs, rhs)
val args = when (node.op) {
Expr.Binary.Op.LT, Expr.Binary.Op.GT,
Expr.Binary.Op.LTE, Expr.Binary.Op.GTE,
Expr.Binary.Op.EQ, Expr.Binary.Op.NE -> {
when {
// Example: [1, 2] < (SELECT a, b FROM t)
isLiteralArray(node.lhs) && isSqlSelect(node.rhs) -> {
val lhs = visitExprCoerce(node.lhs, context)
val rhs = visitExprCoerce(node.rhs, context, Rex.Op.Subquery.Coercion.ROW)
listOf(lhs, rhs)
}
// Example: (SELECT a, b FROM t) < [1, 2]
isSqlSelect(node.lhs) && isLiteralArray(node.rhs) -> {
val lhs = visitExprCoerce(node.lhs, context, Rex.Op.Subquery.Coercion.ROW)
val rhs = visitExprCoerce(node.rhs, context)
listOf(lhs, rhs)
}
// Example: 1 < 2
else -> {
val lhs = visitExprCoerce(node.lhs, context)
val rhs = visitExprCoerce(node.rhs, context)
listOf(lhs, rhs)
}
}
}
// Example: 1 + 2
else -> {
val lhs = visitExprCoerce(node.lhs, context)
val rhs = visitExprCoerce(node.rhs, context)
listOf(lhs, rhs)
}
}
// Wrap if a NOT if necessary
return when (node.op) {
Expr.Binary.Op.NE -> {
val op = negate(call("eq", lhs, rhs))
val op = negate(call("eq", *args.toTypedArray()))
rex(type, op)
}
else -> {
// Fn
val id = identifierSymbol(node.op.name.lowercase(), Identifier.CaseSensitivity.SENSITIVE)
val fn = fnUnresolved(id, true)
// Rex
val op = rexOpCallStatic(fn, args)
rex(type, op)
}
}
}

private fun isLiteralArray(node: Expr): Boolean = node is Expr.Collection && (node.type == Expr.Collection.Type.ARRAY || node.type == Expr.Collection.Type.LIST)

private fun isSqlSelect(node: Expr): Boolean = node is Expr.SFW && (node.select is Select.Project || node.select is Select.Star)

private fun mergeIdentifiers(root: Identifier, steps: List<Identifier>): Identifier {
if (steps.isEmpty()) {
return root
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ internal class PlanTyper(
}
is FnMatch.Error -> {
handleUnknownFunction(match)
rexErr("Unknown scalar function $fn")
rexErr("Unknown scalar function $fn for args: $args")
}
}
}
Expand Down
Loading

0 comments on commit cda3869

Please sign in to comment.