Skip to content

Commit

Permalink
plan Typer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yliuuuu committed Nov 14, 2023
1 parent 9e1bebe commit 140abd0
Show file tree
Hide file tree
Showing 25 changed files with 1,640 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2052,6 +2052,7 @@ internal class PartiQLParserDefault : PartiQLParser {
val n = ctx.arg0?.text?.toInt()
when (ctx.datatype.type) {
GeneratedParser.FLOAT -> when (n) {
null -> typeFloat64()
32 -> typeFloat32()
64 -> typeFloat64()
else -> throw error(ctx.datatype, "Invalid FLOAT precision. Expected 32 or 64")
Expand Down
1 change: 1 addition & 0 deletions partiql-planner/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies {
// Test
testImplementation(project(":partiql-parser"))
testImplementation(project(":plugins:partiql-local"))
testImplementation(project(":plugins:partiql-memory"))
// Test Fixtures
testFixturesImplementation(project(":partiql-spi"))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,28 +233,31 @@ object PartiQLHeader : Header() {

// SPECIAL FORMS

private fun like(): List<FunctionSignature.Scalar> = listOf(
FunctionSignature.Scalar(
name = "like",
returns = BOOL,
parameters = listOf(
FunctionParameter("value", STRING),
FunctionParameter("pattern", STRING),
private fun like(): List<FunctionSignature.Scalar> = types.text.flatMap { t ->
listOf(
FunctionSignature.Scalar(
name = "like",
returns = BOOL,
parameters = listOf(
FunctionParameter("value", t),
FunctionParameter("pattern", t),
),
isNullCall = true,
isNullable = false,
),
isNullCall = true,
isNullable = false,
),
FunctionSignature.Scalar(
name = "like_escape",
returns = BOOL,
parameters = listOf(
FunctionParameter("value", STRING),
FunctionParameter("pattern", STRING),
FunctionParameter("escape", STRING),
FunctionSignature.Scalar(
name = "like_escape",
returns = BOOL,
parameters = listOf(
FunctionParameter("value", t),
FunctionParameter("pattern", t),
FunctionParameter("escape", t),
),
isNullCall = true,
isNullable = false,
),
isNullCall = true,
isNullable = false,
),
)
} + listOf(
FunctionSignature.Scalar(
name = "like",
returns = BOOL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,14 @@ internal class FnResolver(private val headers: List<Header>) {
}
}
}
// we made a match
return mapping
// if all elements requires casting, then no match
// because there must be another function definition that requires no casting
return if (mapping.contains(null)) {
// we made a match
mapping
} else {
null
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ import org.partiql.types.TupleConstraint
import org.partiql.types.function.FunctionSignature
import org.partiql.value.BoolValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType
import org.partiql.value.TextValue
import org.partiql.value.boolValue

Expand Down Expand Up @@ -510,6 +511,7 @@ internal class PlanTyper(
// Type the arguments
val fn = node.fn as Fn.Unresolved
val isEq = fn.isEq()
val isTypeAssertion = fn.isTypeAssertion()
var missingArg = false
val args = node.args.map {
val arg = visitRex(it, null)
Expand All @@ -518,7 +520,7 @@ internal class PlanTyper(
}

// 7.1 All functions return MISSING when one of their inputs is MISSING (except `=`)
if (missingArg && !isEq) {
if (missingArg && !isEq && !isTypeAssertion) {
handleAlwaysMissing()
return rex(StaticType.MISSING, rexOpCall(fn, args))
}
Expand All @@ -532,6 +534,13 @@ internal class PlanTyper(
val newArgs = rewriteFnArgs(match.mapping, args)
val returns = newFn.signature.returns

// dynamic function resolution should only be called upon the arg type is any type or any of type
if (newFn.signature.parameters.all { it.type == PartiQLValueType.ANY } && newFn.signature.isMissable) {
if (!newArgs.map { it.type }.any { it is AnyType || it is AnyOfType }) {
handleAlwaysMissing()
return rex(StaticType.MISSING, rexOpCall(fn, args))
}
}
// Determine the nullability of the return type
var isNull = false // True iff NULL CALL and has a NULL arg
var isNullable = false // True iff NULL CALL and has a NULLABLE arg; or is a NULLABLE operator
Expand All @@ -557,7 +566,7 @@ internal class PlanTyper(
}

// Some operators can return MISSING during runtime
if (match.isMissable && !isEq) {
if (match.isMissable && !isEq && !isTypeAssertion) {
type = StaticType.unionOf(type, StaticType.MISSING)
}

Expand Down Expand Up @@ -1348,6 +1357,10 @@ internal class PlanTyper(
return (identifier is Identifier.Symbol && (identifier as Identifier.Symbol).symbol == "eq")
}

private fun Fn.Unresolved.isTypeAssertion(): Boolean {
return (identifier is Identifier.Symbol && (identifier as Identifier.Symbol).symbol.startsWith("is"))
}

/**
* This will make all binding values nullables. If the value is a struct, each field will be nullable.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,17 @@ internal class TypeLattice private constructor(
)
graph[BOOL] = relationships(
BOOL to coercion(),
INT8 to coercion(),
INT16 to coercion(),
INT32 to coercion(),
INT64 to coercion(),
INT to coercion(),
DECIMAL to coercion(),
FLOAT32 to coercion(),
FLOAT64 to coercion(),
CHAR to coercion(),
STRING to coercion(),
SYMBOL to coercion(),
INT8 to explicit(),
INT16 to explicit(),
INT32 to explicit(),
INT64 to explicit(),
INT to explicit(),
DECIMAL to explicit(),
FLOAT32 to explicit(),
FLOAT64 to explicit(),
CHAR to explicit(),
STRING to explicit(),
SYMBOL to explicit(),
)
graph[INT8] = relationships(
BOOL to explicit(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package org.partiql.planner

import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test

class HeaderTest {

@Test
@Disabled
// @Disabled
fun print() {
println(PartiQLHeader)
println(PartiQLHeader.toString())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package org.partiql.planner.typer

import com.amazon.ionelement.api.ionString
import com.amazon.ionelement.api.ionStructOf
import org.junit.jupiter.api.DynamicContainer
import org.junit.jupiter.api.DynamicTest
import org.partiql.errors.Problem
import org.partiql.errors.ProblemCallback
import org.partiql.errors.ProblemSeverity
import org.partiql.parser.PartiQLParserBuilder
import org.partiql.plan.Statement
import org.partiql.planner.PartiQLPlanner
import org.partiql.planner.PartiQLPlannerBuilder
import org.partiql.planner.test.PartiQLTest
import org.partiql.planner.test.PartiQLTestProvider
import org.partiql.plugins.memory.MemoryCatalog
import org.partiql.plugins.memory.MemoryPlugin
import org.partiql.types.StaticType
import java.util.Random
import java.util.stream.Stream

abstract class PartiQLTyperTestBase {
sealed class TestResult {
data class Success(val expectedType: StaticType) : TestResult() {
override fun toString(): String = "Success_$expectedType"
}

object Failure : TestResult() {
override fun toString(): String = "Failure"
}
}

internal class ProblemCollector : ProblemCallback {
private val problemList = mutableListOf<Problem>()

val problems: List<Problem>
get() = problemList

val hasErrors: Boolean
get() = problemList.any { it.details.severity == ProblemSeverity.ERROR }

val hasWarnings: Boolean
get() = problemList.any { it.details.severity == ProblemSeverity.WARNING }

override fun invoke(problem: Problem) {
problemList.add(problem)
}
}

companion object {
internal val session: ((String) -> PartiQLPlanner.Session) = { catalog ->
PartiQLPlanner.Session(
queryId = Random().nextInt().toString(),
userId = "test-user",
currentCatalog = catalog,
catalogConfig = mapOf(
catalog to ionStructOf(
"connector_name" to ionString("memory")
)
)
)
}
}

val inputs = PartiQLTestProvider().apply { load() }

val testingPipeline: ((String, String, MemoryCatalog.Provider, ProblemCallback) -> PartiQLPlanner.Result) = { query, catalog, catalogProvider, collector ->
val ast = PartiQLParserBuilder.standard().build().parse(query).root
val planner = PartiQLPlannerBuilder().plugins(listOf(MemoryPlugin(catalogProvider))).build()
planner.plan(ast, session(catalog), collector)
}

fun testGen(
testCategory: String,
tests: List<PartiQLTest>,
argsMap: Map<TestResult, Set<List<StaticType>>>,
): Stream<DynamicContainer> {
val catalogProvider = MemoryCatalog.Provider()

return tests.map { test ->
val group = test.statement
val children = argsMap.flatMap { (key, value) ->
value.mapIndexed { index: Int, types: List<StaticType> ->
val testName = "${testCategory}_${key}_$index"
catalogProvider[testName] = MemoryCatalog.of(
*(
types.mapIndexed { i, t ->
"t${i + 1}" to t
}.toTypedArray()
)
)
val displayName = "$group | $testName | $types"
val statement = test.statement
// Assert
DynamicTest.dynamicTest(displayName) {
val pc = ProblemCollector()
if (key is TestResult.Success) {
val result = testingPipeline(statement, testName, catalogProvider, pc)
val root = (result.plan.statement as Statement.Query).root
val actualType = root.type
assert(actualType == key.expectedType) {
"""
expected Type is : ${key.expectedType}
actual Type is : $actualType
""".trimIndent()
}
} else {
val result = testingPipeline(statement, testName, catalogProvider, pc)
val root = (result.plan.statement as Statement.Query).root
val actualType = root.type
assert(actualType == StaticType.MISSING) {
"""
expected Type is : missing
actual Type is : $actualType
""".trimIndent()
}
}
}
}
}
DynamicContainer.dynamicContainer(group, children)
}.stream()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package org.partiql.planner.typer.operator

import org.junit.jupiter.api.DynamicContainer
import org.junit.jupiter.api.TestFactory
import org.partiql.planner.typer.PartiQLTyperTestBase
import org.partiql.planner.util.CastType
import org.partiql.planner.util.allNumberType
import org.partiql.planner.util.allSupportedType
import org.partiql.planner.util.cartesianProduct
import org.partiql.planner.util.castTable
import org.partiql.types.StaticType
import java.util.stream.Stream

class OpArithmeticTest : PartiQLTyperTestBase() {
@TestFactory
fun arithmetic(): Stream<DynamicContainer> {
val tests = listOf(
"expr-61",
"expr-62",
"expr-63",
"expr-64",
"expr-65",
).map { inputs.get("basics", it)!! }

val argsMap: Map<TestResult, Set<List<StaticType>>> = buildMap {
val successArgs = (allNumberType + listOf(StaticType.NULL, StaticType.MISSING))
.let { cartesianProduct(it, it) }
val failureArgs = cartesianProduct(
allSupportedType,
allSupportedType
).filterNot {
successArgs.contains(it)
}.toSet()

successArgs.forEach { args: List<StaticType> ->
val arg0 = args.first()
val arg1 = args[1]
if (args.contains(StaticType.MISSING)) {
(this[TestResult.Success(StaticType.MISSING)] ?: setOf(args)).let {
put(TestResult.Success(StaticType.MISSING), it + setOf(args))
}
} else if (args.contains(StaticType.NULL)) {
(this[TestResult.Success(StaticType.NULL)] ?: setOf(args)).let {
put(TestResult.Success(StaticType.NULL), it + setOf(args))
}
} else if (arg0 == arg1) {
(this[TestResult.Success(arg1)] ?: setOf(args)).let {
put(TestResult.Success(arg1), it + setOf(args))
}
} else if (castTable(arg1, arg0) == CastType.COERCION) {
(this[TestResult.Success(arg0)] ?: setOf(args)).let {
put(TestResult.Success(arg0), it + setOf(args))
}
} else {
(this[TestResult.Success(arg1)] ?: setOf(args)).let {
put(TestResult.Success(arg1), it + setOf(args))
}
}
Unit
}

put(TestResult.Failure, failureArgs)
}

return super.testGen("arithmetic", tests, argsMap)
}
}
Loading

0 comments on commit 140abd0

Please sign in to comment.