From acec20673b4e1b9fcb31a58b8aca4f81786ac64e Mon Sep 17 00:00:00 2001 From: yliuuuu <107505258+yliuuuu@users.noreply.github.com> Date: Wed, 20 Dec 2023 14:51:38 -0800 Subject: [PATCH] Fixes variable resolution (#1322) --- .../org/partiql/planner/internal/Env.kt | 26 ++- .../planner/internal/typer/PlanTyper.kt | 48 +++-- .../kotlin/org/partiql/planner/PlanTest.kt | 164 ++++++++++++++++++ .../planner/util/PlanNodeEquivalentVisitor.kt | 158 +++++++++++++++++ .../test/resources/outputs/basics/select.sql | 50 ++++++ 5 files changed, 424 insertions(+), 22 deletions(-) create mode 100644 partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt create mode 100644 partiql-planner/src/test/kotlin/org/partiql/planner/util/PlanNodeEquivalentVisitor.kt create mode 100644 partiql-planner/src/test/resources/outputs/basics/select.sql diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt index 80a7eff38..5bc50a343 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt @@ -77,16 +77,18 @@ internal sealed interface ResolvedVar { * * @property type Resolved StaticType * @property ordinal Index offset in [TypeEnv] - * @property replacementSteps Path steps to replace. - * @property depth The depth/level of the path match. + * @property resolvedSteps The fully resolved path steps.s */ class Local( override val type: StaticType, override val ordinal: Int, val rootType: StaticType, - val replacementSteps: List, - override val depth: Int - ) : ResolvedVar + val resolvedSteps: List, + ) : ResolvedVar { + // the depth are always going to be 1 because this is local variable. + // the global path, however the path length maybe, going to be replaced by a binding name. + override val depth: Int = 1 + } /** * Metadata for a resolved global variable @@ -233,7 +235,7 @@ internal class Env( catalogs[catalogIndex] = catalogs[catalogIndex].copy( symbols = symbols + listOf(Catalog.Symbol(valuePath, valueType)) ) - catalogIndex to 0 + catalogIndex to catalogs[catalogIndex].symbols.lastIndex } else -> { catalogIndex to index @@ -325,7 +327,7 @@ internal class Env( locals.forEachIndexed { ordinal, binding -> val root = path.steps[0] if (root.isEquivalentTo(binding.name)) { - return ResolvedVar.Local(binding.type, ordinal, binding.type, emptyList(), 1) + return ResolvedVar.Local(binding.type, ordinal, binding.type, path.steps) } } @@ -333,11 +335,17 @@ internal class Env( val matches = mutableListOf() for (ordinal in locals.indices) { val rootType = locals[ordinal].type + val pathPrefix = BindingName(locals[ordinal].name, BindingCase.SENSITIVE) if (rootType is StructType) { val varType = inferStructLookup(rootType, path) if (varType != null) { // we found this path within a struct! - val match = ResolvedVar.Local(varType.resolvedType, ordinal, rootType, varType.replacementPath.steps, varType.replacementPath.steps.size) + val match = ResolvedVar.Local( + varType.resolvedType, + ordinal, + rootType, + listOf(pathPrefix) + varType.replacementPath.steps, + ) matches.add(match) } } @@ -413,7 +421,7 @@ internal class Env( } } // 3. Struct is open - else -> null + else -> key to StaticType.ANY } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index f9cfd4a9f..f2d9ca566 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -451,25 +451,47 @@ internal class PlanTyper( return rex(ANY, rexOpErr("Undefined variable ${node.identifier}")) } val type = resolvedVar.type - val op = when (resolvedVar) { - is ResolvedVar.Global -> rexOpGlobal(catalogSymbolRef(resolvedVar.ordinal, resolvedVar.position)) - is ResolvedVar.Local -> rexOpVarResolved(resolvedVar.ordinal) // resolvedLocalPath(resolvedVar) - } - val variable = rex(type, op) - return when (resolvedVar.depth) { - path.steps.size -> variable - else -> { - val foldedPath = path.steps.subList(resolvedVar.depth, path.steps.size).fold(variable) { current, step -> - when (step.bindingCase) { - BindingCase.SENSITIVE -> rex(ANY, rexOpPathKey(current, rex(STRING, rexOpLit(stringValue(step.name))))) - BindingCase.INSENSITIVE -> rex(ANY, rexOpPathSymbol(current, step.name)) + return when (resolvedVar) { + is ResolvedVar.Global -> { + val variable = rex(type, rexOpGlobal(catalogSymbolRef(resolvedVar.ordinal, resolvedVar.position))) + when (resolvedVar.depth) { + path.steps.size -> variable + else -> { + val foldedPath = foldPath(path.steps, resolvedVar.depth, path.steps.size, variable) + visitRex(foldedPath, ctx) + } + } + } + is ResolvedVar.Local -> { + val variable = rex(type, rexOpVarResolved(resolvedVar.ordinal)) + when { + path.isEquivalentTo(resolvedVar.resolvedSteps) && path.steps.size == resolvedVar.depth -> variable + else -> { + val foldedPath = foldPath(resolvedVar.resolvedSteps, resolvedVar.depth, resolvedVar.resolvedSteps.size, variable) + visitRex(foldedPath, ctx) } } - visitRex(foldedPath, ctx) } } } + private fun foldPath(path: List, start: Int, end: Int, global: Rex) = + path.subList(start, end).fold(global) { current, step -> + when (step.bindingCase) { + BindingCase.SENSITIVE -> rex(ANY, rexOpPathKey(current, rex(STRING, rexOpLit(stringValue(step.name))))) + BindingCase.INSENSITIVE -> rex(ANY, rexOpPathSymbol(current, step.name)) + } + } + + private fun BindingPath.isEquivalentTo(other: List): Boolean { + this.steps.forEachIndexed { index, bindingName -> + if (bindingName != other[index]) { + return false + } + } + return true + } + override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Rex { val catalog = env.catalogs[node.ref.catalog] val type = catalog.symbols[node.ref.symbol].type diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt new file mode 100644 index 000000000..d420d75b9 --- /dev/null +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt @@ -0,0 +1,164 @@ +package org.partiql.planner + +import org.junit.jupiter.api.DynamicContainer +import org.junit.jupiter.api.DynamicContainer.dynamicContainer +import org.junit.jupiter.api.DynamicNode +import org.junit.jupiter.api.DynamicTest +import org.junit.jupiter.api.TestFactory +import org.partiql.parser.PartiQLParser +import org.partiql.plan.PlanNode +import org.partiql.plan.debug.PlanPrinter +import org.partiql.planner.test.PartiQLTest +import org.partiql.planner.test.PartiQLTestProvider +import org.partiql.planner.util.PlanNodeEquivalentVisitor +import org.partiql.planner.util.ProblemCollector +import org.partiql.plugins.memory.MemoryConnector +import org.partiql.types.BagType +import org.partiql.types.StaticType +import org.partiql.types.StructType +import org.partiql.types.TupleConstraint +import java.io.File +import java.nio.file.Path +import java.time.Instant +import java.util.stream.Stream +import kotlin.io.path.toPath + +// Prevent Unintentional break of the plan +// We currently don't have a good way to assert on the result plan +// so we assert on having the partiql text. +// The input text and the normalized partiql text should produce identical plan. +// I.e., +// if the input text is `SELECT a,b,c FROM T` +// the produced plan will be identical as the normalized query: +// `SELECT "T"['a'] AS "a", "T"['b'] AS "b", "T"['c'] AS "c" FROM "default"."T" AS "T";` +class PlanTest { + val root: Path = this::class.java.getResource("/outputs")!!.toURI().toPath() + + val input = PartiQLTestProvider().apply { load() } + + val session: (PartiQLTest.Key) -> PartiQLPlanner.Session = { key -> + PartiQLPlanner.Session( + queryId = key.toString(), + userId = "user_id", + currentCatalog = "default", + currentDirectory = listOf(), + instant = Instant.now(), + ) + } + + val metadata = MemoryConnector.Metadata.of( + "default.t" to BagType( + StructType( + listOf( + StructType.Field("a", StaticType.BOOL), + StructType.Field("b", StaticType.INT4), + StructType.Field("c", StaticType.STRING), + StructType.Field( + "d", + StructType( + listOf(StructType.Field("e", StaticType.STRING)), + contentClosed = true, + emptyList(), + setOf(TupleConstraint.Open(false)), + emptyMap() + ) + ), + StructType.Field("x", StaticType.ANY), + StructType.Field("z", StaticType.STRING), + StructType.Field("v", StaticType.STRING), + ), + contentClosed = true, + emptyList(), + setOf(TupleConstraint.Open(false)), + emptyMap() + ) + ) + ) + + val pipeline: (PartiQLTest) -> PartiQLPlanner.Result = { test -> + val problemCollector = ProblemCollector() + val ast = PartiQLParser.default().parse(test.statement).root + val planner = PartiQLPlannerBuilder() + .addCatalog("default", metadata) + .build() + planner.plan(ast, session(test.key), problemCollector) + } + + @TestFactory + fun factory(): Stream { + val r = root.toFile() + return r + .listFiles { f -> f.isDirectory }!! + .mapNotNull { load(r, it) } + .stream() + } + + private fun load(parent: File, file: File): DynamicNode? = when { + file.isDirectory -> loadD(parent, file) + file.extension == "sql" -> loadF(parent, file) + else -> null + } + + private fun loadD(parent: File, file: File): DynamicContainer { + val name = file.name + val children = file.listFiles()!!.map { load(file, it) } + return dynamicContainer(name, children) + } + + private fun loadF(parent: File, file: File): DynamicContainer { + val group = parent.name + val tests = parse(group, file) + + val children = tests.map { + // Prepare + val displayName = it.key.toString() + + // Assert + DynamicTest.dynamicTest(displayName) { + val input = input[it.key] ?: error("no test cases") + + val inputPlan = pipeline.invoke(input).plan + val outputPlan = pipeline.invoke(it).plan + assert(inputPlan.isEquaivalentTo(outputPlan)) { + buildString { + this.appendLine("expect plan equivalence") + PlanPrinter.append(this, inputPlan) + PlanPrinter.append(this, outputPlan) + } + } + } + } + return dynamicContainer(file.nameWithoutExtension, children) + } + + private fun parse(group: String, file: File): List { + val tests = mutableListOf() + var name = "" + val statement = StringBuilder() + for (line in file.readLines()) { + // start of test + if (line.startsWith("--#[") and line.endsWith("]")) { + name = line.substring(4, line.length - 1) + statement.clear() + } + if (name.isNotEmpty() && line.isNotBlank()) { + // accumulating test statement + statement.appendLine(line) + } else { + // skip these lines + continue + } + // Finish & Reset + if (line.endsWith(";")) { + val key = PartiQLTest.Key(group, name) + tests.add(PartiQLTest(key, statement.toString())) + name = "" + statement.clear() + } + } + return tests + } + + private fun PlanNode.isEquaivalentTo(other: PlanNode): Boolean = + PlanNodeEquivalentVisitor().visit(this, other) +} diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/util/PlanNodeEquivalentVisitor.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/util/PlanNodeEquivalentVisitor.kt new file mode 100644 index 000000000..04179370b --- /dev/null +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/util/PlanNodeEquivalentVisitor.kt @@ -0,0 +1,158 @@ +package org.partiql.planner.util + +import org.partiql.plan.Agg +import org.partiql.plan.Catalog +import org.partiql.plan.Fn +import org.partiql.plan.Identifier +import org.partiql.plan.PlanNode +import org.partiql.plan.Rel +import org.partiql.plan.Rex +import org.partiql.plan.visitor.PlanBaseVisitor +import org.partiql.value.PartiQLValueExperimental + +// Work around to assert plan equivalence, +// perhaps the easier way is to have an is equivalent method at code generation time +// but this is good enough for the purpose of testing at the moment. +class PlanNodeEquivalentVisitor : PlanBaseVisitor() { + override fun visit(node: PlanNode, ctx: PlanNode): Boolean = node.accept(this, ctx) + + override fun visitCatalog(node: Catalog, ctx: PlanNode): Boolean { + if (!super.visitCatalog(node, ctx)) return false + ctx as Catalog + if (node.name != ctx.name) return false + return true + } + + override fun visitCatalogSymbol(node: Catalog.Symbol, ctx: PlanNode): Boolean { + if (!super.visitCatalogSymbol(node, ctx)) return false + ctx as Catalog.Symbol + if (node.path != ctx.path) return false + if (node.type != ctx.type) return false + return true + } + + override fun visitCatalogSymbolRef(node: Catalog.Symbol.Ref, ctx: PlanNode): Boolean { + if (!super.visitCatalogSymbolRef(node, ctx)) return false + ctx as Catalog.Symbol.Ref + if (node.catalog != ctx.catalog) return false + if (node.symbol != ctx.symbol) return false + return true + } + + override fun visitFn(node: Fn, ctx: PlanNode): Boolean { + if (!super.visitFn(node, ctx)) return false + ctx as Fn + if (node.signature != ctx.signature) return false + return true + } + + override fun visitAgg(node: Agg, ctx: PlanNode): Boolean { + if (!super.visitAgg(node, ctx)) return false + ctx as Agg + if (node.signature != ctx.signature) return false + return true + } + + override fun visitIdentifierSymbol(node: Identifier.Symbol, ctx: PlanNode): Boolean { + if (!super.visitIdentifierSymbol(node, ctx)) return false + ctx as Identifier.Symbol + if (node.symbol != ctx.symbol) return false + if (node.caseSensitivity != ctx.caseSensitivity) return false + return true + } + + override fun visitRex(node: Rex, ctx: PlanNode): Boolean { + if (!super.visitRex(node, ctx)) return false + ctx as Rex + if (node.type != ctx.type) return false + return true + } + + @OptIn(PartiQLValueExperimental::class) + override fun visitRexOpLit(node: Rex.Op.Lit, ctx: PlanNode): Boolean { + if (!super.visitRexOpLit(node, ctx)) return false + ctx as Rex.Op.Lit + if (node.value != ctx.value) return false + return true + } + + override fun visitRexOpVar(node: Rex.Op.Var, ctx: PlanNode): Boolean { + if (!super.visitRexOpVar(node, ctx)) return false + ctx as Rex.Op.Var + if (node.ref != ctx.ref) return false + return true + } + + override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: PlanNode): Boolean { + if (!super.visitRexOpPathSymbol(node, ctx)) return false + ctx as Rex.Op.Path.Symbol + if (node.key != ctx.key) return false + return true + } + + override fun visitRexOpErr(node: Rex.Op.Err, ctx: PlanNode): Boolean { + if (!super.visitRexOpErr(node, ctx)) return false + ctx as Rex.Op.Err + if (node.message != ctx.message) return false + return true + } + + override fun visitRelType(node: Rel.Type, ctx: PlanNode): Boolean { + if (!super.visitRelType(node, ctx)) return false + ctx as Rel.Type + if (node.props != ctx.props) return false + return true + } + + override fun visitRelOpSortSpec(node: Rel.Op.Sort.Spec, ctx: PlanNode): Boolean { + if (!super.visitRelOpSortSpec(node, ctx)) return false + ctx as Rel.Op.Sort.Spec + if (node.order != ctx.order) return false + return true + } + + override fun visitRelOpJoin(node: Rel.Op.Join, ctx: PlanNode): Boolean { + if (!super.visitRelOpJoin(node, ctx)) return false + ctx as Rel.Op.Join + if (node.type != ctx.type) return false + return true + } + + override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: PlanNode): Boolean { + if (!super.visitRelOpAggregate(node, ctx)) return false + ctx as Rel.Op.Aggregate + if (node.strategy != ctx.strategy) return false + return true + } + + override fun visitRelOpExcludeStepCollIndex(node: Rel.Op.Exclude.Step.CollIndex, ctx: PlanNode): Boolean { + if (!super.visitRelOpExcludeStepCollIndex(node, ctx)) return false + ctx as Rel.Op.Exclude.Step.CollIndex + if (node.index != ctx.index) return false + return true + } + + override fun visitRelOpErr(node: Rel.Op.Err, ctx: PlanNode): Boolean { + if (!super.visitRelOpErr(node, ctx)) return false + ctx as Rel.Op.Err + if (node.message != ctx.message) return false + return true + } + + override fun visitRelBinding(node: Rel.Binding, ctx: PlanNode): Boolean { + if (!super.visitRelBinding(node, ctx)) return false + ctx as Rel.Binding + if (node.name != ctx.name) return false + if (node.type != ctx.type) return false + return true + } + + override fun defaultReturn(node: PlanNode, ctx: PlanNode): Boolean { + if (ctx.javaClass != node.javaClass) return false + if (node.children.size != ctx.children.size) return false + node.children.forEachIndexed { index, planNode -> + if (planNode.accept(this, ctx.children[index])) return false + } + return true + } +} diff --git a/partiql-planner/src/test/resources/outputs/basics/select.sql b/partiql-planner/src/test/resources/outputs/basics/select.sql new file mode 100644 index 000000000..ceead5f36 --- /dev/null +++ b/partiql-planner/src/test/resources/outputs/basics/select.sql @@ -0,0 +1,50 @@ +--#[select-00] +SELECT "T"['a'] AS "a", "T"['b'] AS "b", "T"['c'] AS "c" FROM "default"."T" AS "T"; + +--#[select-01] +SELECT "T".* FROM "default"."T" AS "T"; + +--#[select-02] +SELECT "T"['a'] AS "a", "T"['b'] AS "b", "T"['c'] AS "c" FROM "default"."T" AS "T"; + +--#[select-03] +SELECT VALUE "T"['a'] FROM "default"."T" AS "T"; + +--#[select-04] +SELECT "t1".*, "t2".* FROM "default"."T" AS "t1" INNER JOIN "default"."T" AS "t2" ON true; + +--#[select-05] +SELECT "T"['d'].* FROM "default"."T" AS "T"; + +--#[select-06] +SELECT "T" AS "t", "T"['d'].* FROM "default"."T" AS "T"; + +--#[select-07] +SELECT "T"['d'].*, "T"['d'].* FROM "default"."T" AS "T"; + +--#[select-08] +SELECT "T"['d'].* FROM "default"."T" AS "T"; + +--#[select-09] +SELECT "T".* FROM "default"."T" AS "T"; + +--#[select-10] +SELECT "T"['c'] || CURRENT_USER AS "_1" FROM "default"."T" AS "T"; + +--#[select-11] +SELECT CURRENT_USER AS "CURRENT_USER" FROM "default"."T" AS "T"; + +--#[select-12] +SELECT CURRENT_DATE AS "CURRENT_DATE" FROM "default"."T" AS "T"; + +--#[select-13] +SELECT DATE_DIFF(DAY, CURRENT_DATE, CURRENT_DATE) AS "_1" FROM "default"."T" AS "T"; + +--#[select-14] +SELECT DATE_ADD(DAY, 5, CURRENT_DATE) AS "_1" FROM "default"."T" AS "T" + +--#[select-15] +SELECT DATE_ADD(DAY, -5, CURRENT_DATE) AS "_1" FROM "default"."T" AS "T" + +--#[select-16] +SELECT "t"['a'] AS "a" FROM "default"."T" AS "t"; \ No newline at end of file