Skip to content

Commit

Permalink
Adds basic scalar and row-value subquery coercion (#1258)
Browse files Browse the repository at this point in the history
  • Loading branch information
rchowell authored Nov 2, 2023
1 parent 262074f commit 5946539
Show file tree
Hide file tree
Showing 11 changed files with 272 additions and 252 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import org.partiql.ast.AstNode
import org.partiql.ast.Expr
import org.partiql.ast.From
import org.partiql.ast.Statement
import org.partiql.ast.builder.ast
import org.partiql.ast.fromJoin
import org.partiql.ast.helpers.toBinder
import org.partiql.ast.util.AstRewriter

Expand All @@ -31,27 +31,33 @@ internal object NormalizeFromSource : AstPass {

private object Visitor : AstRewriter<Int>() {

// Each SFW starts the ctx count again.
override fun visitExprSFW(node: Expr.SFW, ctx: Int): AstNode = super.visitExprSFW(node, 0)

override fun visitStatementDMLBatchLegacy(node: Statement.DML.BatchLegacy, ctx: Int): AstNode =
super.visitStatementDMLBatchLegacy(node, 0)

override fun visitFrom(node: From, ctx: Int) = super.visitFrom(node, ctx) as From

override fun visitFromJoin(node: From.Join, ctx: Int) = ast {
override fun visitFromJoin(node: From.Join, ctx: Int): From {
val lhs = visitFrom(node.lhs, ctx)
val rhs = visitFrom(node.rhs, ctx + 1)
val condition = node.condition?.let { visitExpr(it, ctx) as Expr }
if (lhs !== node.lhs || rhs !== node.rhs || condition !== node.condition) {
return if (lhs !== node.lhs || rhs !== node.rhs || condition !== node.condition) {
fromJoin(lhs, rhs, node.type, condition)
} else {
node
}
}

override fun visitFromValue(node: From.Value, ctx: Int) = when (node.asAlias) {
null -> node.copy(asAlias = node.expr.toBinder(ctx))
else -> node
override fun visitFromValue(node: From.Value, ctx: Int): From {
val expr = visitExpr(node.expr, ctx) as Expr
val asAlias = node.asAlias ?: expr.toBinder(ctx)
return if (expr !== node.expr || asAlias !== node.asAlias) {
node.copy(expr = expr, asAlias = asAlias)
} else {
node
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ import org.partiql.value.stringValue
* } FROM A AS x
* ```
*
* TODO: GROUP BY
* TODO: LET
*
* Requires [NormalizeFromSource].
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import org.junit.jupiter.params.provider.ArgumentsSource
import org.junit.jupiter.params.provider.MethodSource
import org.partiql.annotations.ExperimentalPartiQLSchemaInferencer
import org.partiql.errors.Problem
import org.partiql.errors.ProblemHandler
import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION
import org.partiql.lang.errors.ProblemCollector
import org.partiql.lang.planner.transforms.PartiQLSchemaInferencerTests.TestCase.ErrorTestCase
Expand Down Expand Up @@ -131,6 +130,11 @@ class PartiQLSchemaInferencerTests {
@Execution(ExecutionMode.CONCURRENT)
fun testCaseWhens(tc: TestCase) = runTest(tc)

@ParameterizedTest
@MethodSource("subqueryCases")
@Execution(ExecutionMode.CONCURRENT)
fun testSubqueries(tc: TestCase) = runTest(tc)

companion object {

private val root = this::class.java.getResource("/catalogs/default")!!.toURI().toPath().pathString
Expand All @@ -156,6 +160,10 @@ class PartiQLSchemaInferencerTests {
field("connector_name", ionString("local")),
field("root", ionString("$root/pql")),
),
"subqueries" to ionStructOf(
field("connector_name", ionString("local")),
field("root", ionString("$root/subqueries")),
),
)

const val CATALOG_AWS = "aws"
Expand Down Expand Up @@ -2463,6 +2471,69 @@ class PartiQLSchemaInferencerTests {
)
),
)

@JvmStatic
fun subqueryCases() = listOf(
SuccessTestCase(
name = "Subquery IN collection",
catalog = "subqueries",
key = PartiQLTest.Key("subquery", "subquery-00"),
expected = BagType(
StructType(
fields = mapOf(
"x" to INT4,
),
contentClosed = true,
constraints = setOf(
TupleConstraint.Open(false),
TupleConstraint.UniqueAttrs(true),
TupleConstraint.Ordered
)
)
)
),
SuccessTestCase(
name = "Subquery scalar coercion",
catalog = "subqueries",
key = PartiQLTest.Key("subquery", "subquery-01"),
expected = BagType(
StructType(
fields = mapOf(
"x" to INT4,
),
contentClosed = true,
constraints = setOf(
TupleConstraint.Open(false),
TupleConstraint.UniqueAttrs(true),
TupleConstraint.Ordered
)
)
)
),
SuccessTestCase(
name = "Subquery simple JOIN",
catalog = "subqueries",
key = PartiQLTest.Key("subquery", "subquery-02"),
expected = BagType(
StructType(
fields = mapOf(
"x" to INT4,
"y" to INT4,
"z" to INT4,
"a" to INT4,
"b" to INT4,
"c" to INT4,
),
contentClosed = true,
constraints = setOf(
TupleConstraint.Open(false),
TupleConstraint.UniqueAttrs(true),
TupleConstraint.Ordered
)
)
)
),
)
}

sealed class TestCase {
Expand All @@ -2474,7 +2545,7 @@ class PartiQLSchemaInferencerTests {
val catalog: String? = null,
val catalogPath: List<String> = emptyList(),
val expected: StaticType,
val warnings: ProblemHandler? = null
val warnings: ProblemHandler? = null,
) : TestCase() {
override fun toString(): String = "$name : $query"
}
Expand Down
8 changes: 3 additions & 5 deletions partiql-plan/src/main/resources/partiql_plan_0_1.ion
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,9 @@ rex::{
rel: rel,
},

coll_to_scalar::{
subquery: {
select: select,
type: static_type // reify `select` type
}
subquery::{
select: select,
coercion: [ SCALAR, ROW ],
},

select::{
Expand Down
Loading

0 comments on commit 5946539

Please sign in to comment.