Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes derived binder numbering #1269

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToBinder.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import org.partiql.ast.builder.ast
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.StringValue

private val col = { index: Int -> "_${index + 1}" }
private val col = { index: () -> Int -> "_${index()}" }

/**
* Produces a "binder" (AS alias) for an expression following the given rules:
Expand All @@ -18,14 +18,22 @@ private val col = { index: Int -> "_${index + 1}" }
*
* See https://github.com/partiql/partiql-lang-kotlin/issues/1122
*/
public fun Expr.toBinder(index: Int): Identifier.Symbol = when (this) {
public fun Expr.toBinder(index: () -> Int): Identifier.Symbol = when (this) {
is Expr.Var -> this.identifier.toBinder()
is Expr.Path -> this.toBinder(index)
is Expr.Cast -> this.value.toBinder(index)
is Expr.SessionAttribute -> this.attribute.name.uppercase().toBinder()
else -> col(index).toBinder()
}

/**
* Simple toBinder that uses an int literal rather than a closure.
*
* @param index
* @return
*/
public fun Expr.toBinder(index: Int): Identifier.Symbol = toBinder { index }

private fun String.toBinder(): Identifier.Symbol = ast {
// Every binder preserves case
identifierSymbol(this@toBinder, Identifier.CaseSensitivity.SENSITIVE)
Expand All @@ -40,14 +48,14 @@ private fun Identifier.toBinder(): Identifier.Symbol = when (this@toBinder) {
}

@OptIn(PartiQLValueExperimental::class)
private fun Expr.Path.toBinder(index: Int): Identifier.Symbol {
private fun Expr.Path.toBinder(index: () -> Int): Identifier.Symbol {
if (steps.isEmpty()) return root.toBinder(index)
return when (val last = steps.last()) {
is Expr.Path.Step.Symbol -> last.symbol.toBinder()
is Expr.Path.Step.Index -> {
val k = last.key
if (k is Expr.Lit && k.value is StringValue) {
(k.value as StringValue).value!!.toBinder()
k.value.value!!.toBinder()
} else {
col(index).toBinder()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.partiql.ast.exprStructField
import org.partiql.ast.exprVar
import org.partiql.ast.helpers.toBinder
import org.partiql.ast.identifierSymbol
import org.partiql.ast.selectProject
import org.partiql.ast.selectProjectItemExpression
import org.partiql.ast.selectValue
import org.partiql.ast.typeStruct
Expand Down Expand Up @@ -86,9 +87,20 @@ import org.partiql.value.stringValue
*/
internal object NormalizeSelect : AstPass {

override fun apply(statement: Statement): Statement = Visitor.visitStatement(statement, 0) as Statement
override fun apply(statement: Statement): Statement = Visitor.visitStatement(statement, newCtx()) as Statement

private object Visitor : AstRewriter<Int>() {
/**
* Closure for incrementing a derived binding counter
*/
private fun newCtx(): () -> Int = run {
var i = 1;
{ i++ }
}

/**
* The type parameter () -> Int
*/
private object Visitor : AstRewriter<() -> Int>() {

/**
* This is used to give projections a name. For example:
Expand All @@ -112,25 +124,35 @@ internal object NormalizeSelect : AstPass {
*/
private val col = { index: Int -> "_${index + 1}" }

override fun visitExprSFW(node: Expr.SFW, ctx: Int): Expr.SFW {
override fun visitExprSFW(node: Expr.SFW, ctx: () -> Int): Expr.SFW {
val sfw = super.visitExprSFW(node, ctx) as Expr.SFW
return when (val select = sfw.select) {
is Select.Star -> sfw.copy(select = visitSelectAll(select, sfw.from))
else -> sfw
}
}

override fun visitSelectProject(node: Select.Project, ctx: Int): AstNode {
val visitedNode = super.visitSelectProject(node, ctx) as? Select.Project
?: error("VisitSelectProject should have returned a Select.Project")
override fun visitSelectProject(node: Select.Project, ctx: () -> Int): AstNode {

// Visit items, adding a binder if necessary
var diff = false
val visitedItems = ArrayList<Select.Project.Item>(node.items.size)
node.items.forEach { n ->
val item = visitSelectProjectItem(n, ctx) as Select.Project.Item
if (item !== n) diff = true
visitedItems.add(item)
}
val visitedNode = if (diff) selectProject(visitedItems, node.setq) else node

// Rewrite selection
return when (node.items.any { it is Select.Project.Item.All }) {
false -> visitSelectProjectWithoutProjectAll(visitedNode)
true -> visitSelectProjectWithProjectAll(visitedNode)
}
}

override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, ctx: Int): Select.Project.Item.Expression {
val expr = visitExpr(node.expr, 0) as Expr
override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, ctx: () -> Int): Select.Project.Item.Expression {
val expr = visitExpr(node.expr, newCtx()) as Expr
val alias = when (node.asAlias) {
null -> expr.toBinder(ctx)
else -> node.asAlias
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package org.partiql.ast.normalize

import org.junit.jupiter.api.Test
import org.partiql.ast.Expr
import org.partiql.ast.From
import org.partiql.ast.Identifier
import org.partiql.ast.Select
import org.partiql.ast.builder.ast
import org.partiql.ast.exprLit
import org.partiql.ast.exprVar
import org.partiql.ast.identifierSymbol
import org.partiql.ast.selectProjectItemExpression
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.int32Value
import org.partiql.value.stringValue
import kotlin.test.assertEquals

class NormalizeSelectTest {

/**
* SELECT a, b, c FROM T
*
* SELECT VALUE {
* 'a': a,
* 'b': b,
* 'c': c
* } FROM T
*/
@Test
fun testDerivedBinders_00() {
val input = select(
varItem("a"),
varItem("b"),
varItem("c"),
)
val expected = selectValue(
"a" to variable("a"),
"b" to variable("b"),
"c" to variable("c"),
)
val actual = NormalizeSelect.apply(input)
assertEquals(expected, actual)
}

/**
* SELECT 1, 2, 3 FROM T
*
* SELECT VALUE {
* '_1': 1,
* '_2': 2,
* '_3': 3
* } FROM T
*/
@Test
fun testDerivedBinders_01() {
val input = select(
litItem(1),
litItem(2),
litItem(3),
)
val expected = selectValue(
"_1" to lit(1),
"_2" to lit(2),
"_3" to lit(3),
)
val actual = NormalizeSelect.apply(input)
assertEquals(expected, actual)
}

/**
* SELECT a, 2, 3 FROM T
*
* SELECT VALUE {
* 'a': a,
* '_1': 2,
* '_2': 3
* } FROM T
*/
@Test
fun testDerivedBinders_02() {
val input = select(
varItem("a"),
litItem(2),
litItem(3),
)
val expected = selectValue(
"a" to variable("a"),
"_1" to lit(2),
"_2" to lit(3),
)
val actual = NormalizeSelect.apply(input)
assertEquals(expected, actual)
}

/**
* SELECT a AS a, 2 AS b, 3 AS c FROM T
*
* SELECT VALUE {
* 'a': a,
* 'b': 2,
* 'c': 3
* } FROM T
*/
@Test
fun testDerivedBinders_03() {
val input = select(
varItem("a", "a"),
litItem(2, "b"),
litItem(3, "c"),
)
val expected = selectValue(
"a" to variable("a"),
"b" to lit(2),
"c" to lit(3),
)
val actual = NormalizeSelect.apply(input)
assertEquals(expected, actual)
}

// ----- HELPERS -------------------------

private fun variable(name: String) = exprVar(
identifier = identifierSymbol(
symbol = name,
caseSensitivity = Identifier.CaseSensitivity.INSENSITIVE,
),
scope = Expr.Var.Scope.DEFAULT,
)

private fun select(vararg items: Select.Project.Item) = ast {
statementQuery {
expr = exprSFW {
select = selectProject {
this.items += items
}
from = fromValue {
expr = variable("T")
type = From.Value.Type.SCAN
}
}
}
}

@OptIn(PartiQLValueExperimental::class)
private fun selectValue(vararg items: Pair<String, Expr>) = ast {
statementQuery {
expr = exprSFW {
select = selectValue {
constructor = exprStruct {
for ((k, v) in items) {
fields += exprStructField {
name = exprLit(stringValue(k))
value = v
}
}
}
}
from = fromValue {
expr = exprVar {
identifier = identifierSymbol {
symbol = "T"
caseSensitivity = Identifier.CaseSensitivity.INSENSITIVE
}
scope = Expr.Var.Scope.DEFAULT
}
type = From.Value.Type.SCAN
}
}
}
}

private fun varItem(symbol: String, asAlias: String? = null) = selectProjectItemExpression(
expr = variable(symbol),
asAlias = asAlias?.let { identifierSymbol(asAlias, Identifier.CaseSensitivity.INSENSITIVE) }
)

private fun litItem(value: Int, asAlias: String? = null) = selectProjectItemExpression(
expr = lit(value),
asAlias = asAlias?.let { identifierSymbol(asAlias, Identifier.CaseSensitivity.INSENSITIVE) }
)

@OptIn(PartiQLValueExperimental::class)
private fun lit(value: Int) = exprLit(int32Value(value))
}
Original file line number Diff line number Diff line change
Expand Up @@ -2474,7 +2474,27 @@ class PartiQLSchemaInferencerTests {
@JvmStatic
fun aggregationCases() = listOf(
SuccessTestCase(
name = "AGGREGATE over INTS",
name = "AGGREGATE over INTS, without alias",
query = "SELECT a, COUNT(*), SUM(a), MIN(b) FROM << {'a': 1, 'b': 2} >> GROUP BY a",
expected = BagType(
StructType(
fields = mapOf(
"a" to INT4,
"_1" to INT4,
"_2" to INT4.asNullable(),
"_3" to INT4.asNullable(),
),
contentClosed = true,
constraints = setOf(
TupleConstraint.Open(false),
TupleConstraint.UniqueAttrs(true),
TupleConstraint.Ordered
)
)
)
),
SuccessTestCase(
name = "AGGREGATE over INTS, with alias",
query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a",
expected = BagType(
StructType(
Expand Down Expand Up @@ -2632,7 +2652,7 @@ class PartiQLSchemaInferencerTests {

class IgnoredTestCase(
val shouldBe: TestCase,
reason: String
reason: String,
) : TestCase() {
override fun toString(): String = "Disabled - $shouldBe"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
type: "bag",
items: {
type: "struct",
constraints: [ closed, ordered, unique ],
fields: [
{
name: "a",
type: "bool",
},
{
name: "b",
type: "int32",
},
{
name: "c",
type: "string",
},
{
name: "d",
type: {
type: "struct",
constraints: [ closed, ordered, unique ],
fields: [
{
name: "e",
type: "string"
}
]
},
}
]
}
}
Loading
Loading