Skip to content

Commit

Permalink
Modifies definition of CASE_WHEN in Plan (#1257)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnedquinn authored Nov 2, 2023
1 parent fb25dfa commit a365a45
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.partiql.types.BagType
import org.partiql.types.ListType
import org.partiql.types.SexpType
import org.partiql.types.StaticType
import org.partiql.types.StaticType.Companion.BOOL
import org.partiql.types.StaticType.Companion.DATE
import org.partiql.types.StaticType.Companion.DECIMAL
import org.partiql.types.StaticType.Companion.INT
Expand Down Expand Up @@ -125,6 +126,11 @@ class PartiQLSchemaInferencerTests {
@Execution(ExecutionMode.CONCURRENT)
fun testPathExpressions(tc: TestCase) = runTest(tc)

@ParameterizedTest
@MethodSource("caseWhens")
@Execution(ExecutionMode.CONCURRENT)
fun testCaseWhens(tc: TestCase) = runTest(tc)

companion object {

private val root = this::class.java.getResource("/catalogs/default")!!.toURI().toPath().pathString
Expand Down Expand Up @@ -2137,6 +2143,150 @@ class PartiQLSchemaInferencerTests {
),
)

@JvmStatic
fun caseWhens() = listOf(
SuccessTestCase(
name = "Easy case when",
query = """
CASE
WHEN FALSE THEN 0
WHEN TRUE THEN 1
ELSE 2
END;
""",
expected = INT4
),
SuccessTestCase(
name = "Folded case when to grab the true",
query = """
CASE
WHEN FALSE THEN 0
WHEN TRUE THEN 'hello'
END;
""",
expected = STRING
),
SuccessTestCase(
name = "Boolean case when",
query = """
CASE 'Hello World'
WHEN 'Hello World' THEN TRUE
ELSE FALSE
END;
""",
expected = BOOL
),
SuccessTestCase(
name = "Folded out false",
query = """
CASE
WHEN FALSE THEN 'IMPOSSIBLE TO GET'
ELSE TRUE
END;
""",
expected = BOOL
),
SuccessTestCase(
name = "Folded out false without default",
query = """
CASE
WHEN FALSE THEN 'IMPOSSIBLE TO GET'
END;
""",
expected = NULL
),
SuccessTestCase(
name = "Not folded gives us a nullable without default",
query = """
CASE 1
WHEN 1 THEN TRUE
WHEN 2 THEN FALSE
END;
""",
expected = BOOL.asNullable()
),
SuccessTestCase(
name = "Not folded gives us a nullable without default for query",
query = """
SELECT
CASE breed
WHEN 'golden retriever' THEN 'fluffy dog'
WHEN 'pitbull' THEN 'short-haired dog'
END AS breed_descriptor
FROM dogs
""",
catalog = "pql",
catalogPath = listOf("main"),
expected = BagType(
StructType(
fields = mapOf(
"breed_descriptor" to STRING.asNullable(),
),
contentClosed = true,
constraints = setOf(
TupleConstraint.Open(false),
TupleConstraint.UniqueAttrs(true),
TupleConstraint.Ordered
)
)
)
),
SuccessTestCase(
name = "Query",
query = """
SELECT
CASE breed
WHEN 'golden retriever' THEN 'fluffy dog'
WHEN 'pitbull' THEN 'short-haired dog'
ELSE 'something else'
END AS breed_descriptor
FROM dogs
""",
catalog = "pql",
catalogPath = listOf("main"),
expected = BagType(
StructType(
fields = mapOf(
"breed_descriptor" to STRING,
),
contentClosed = true,
constraints = setOf(
TupleConstraint.Open(false),
TupleConstraint.UniqueAttrs(true),
TupleConstraint.Ordered
)
)
)
),
SuccessTestCase(
name = "Query with heterogeneous data",
query = """
SELECT
CASE breed
WHEN 'golden retriever' THEN 'fluffy dog'
WHEN 'pitbull' THEN 2
ELSE 2.0
END AS breed_descriptor
FROM dogs
""",
catalog = "pql",
catalogPath = listOf("main"),
expected = BagType(
StructType(
fields = mapOf(
"breed_descriptor" to unionOf(STRING, INT4, DECIMAL),
),
contentClosed = true,
constraints = setOf(
TupleConstraint.Open(false),
TupleConstraint.UniqueAttrs(true),
TupleConstraint.Ordered
)
)
)
),
)

@JvmStatic
fun pathExpressions() = listOf(
SuccessTestCase(
Expand Down
1 change: 1 addition & 0 deletions partiql-plan/src/main/resources/partiql_plan_0_1.ion
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ rex::{

case::{
branches: list::[branch],
default: rex,
_: [
branch::{
condition: rex,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,7 @@ internal object RexConverter {

override fun visitExprCase(node: Expr.Case, context: Env) = plan {
val type = (StaticType.ANY)
val rex = when (node.expr) {
null -> null
else -> visitExpr(node.expr!!, context) // match `rex
}
val rex = node.expr?.let { visitExpr(it, context) }

// Converts AST CASE (x) WHEN y THEN z --> Plan CASE WHEN x = y THEN z
val id = identifierSymbol(Expr.Binary.Op.EQ.name.lowercase(), Identifier.CaseSensitivity.SENSITIVE)
Expand All @@ -188,8 +185,7 @@ internal object RexConverter {
null -> rex(type = StaticType.NULL, op = rexOpLit(value = nullValue()))
else -> visitExpr(default, context)
}
branches += rexOpCaseBranch(bool(true), defaultRex)
val op = rexOpCase(branches)
val op = rexOpCase(branches = branches, default = defaultRex)
rex(type, op)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import org.partiql.spi.BindingPath
import org.partiql.types.AnyOfType
import org.partiql.types.AnyType
import org.partiql.types.BagType
import org.partiql.types.BoolType
import org.partiql.types.CollectionType
import org.partiql.types.IntType
import org.partiql.types.ListType
Expand Down Expand Up @@ -564,12 +565,32 @@ internal class PlanTyper(
.map { visitRexOpCaseBranch(it, it.rex.type) }
.filterNot { isLiteralBool(it.condition, false) }

newBranches.forEach { branch ->
if (canBeBoolean(branch.condition.type).not()) {
onProblem.invoke(
Problem(
UNKNOWN_PROBLEM_LOCATION,
PlanningProblemDetails.IncompatibleTypesForOp(branch.condition.type.allTypes, "CASE_WHEN")
)
)
}
}
val default = visitRex(node.default, node.default.type)

// Calculate final expression (short-circuit to first branch if the condition is always TRUE).
val resultTypes = newBranches.map { it.rex }.map { it.type }
val firstBranch = newBranches.firstOrNull() ?: error("CASE_WHEN has NO branches.")
return when (isLiteralBool(firstBranch.condition, true)) {
true -> firstBranch.rex
false -> rex(type = StaticType.unionOf(resultTypes.toSet()).flatten(), node.copy(branches = newBranches))
val resultTypes = newBranches.map { it.rex }.map { it.type } + listOf(default.type)
return when (newBranches.size) {
0 -> default
else -> when (isLiteralBool(newBranches[0].condition, true)) {
true -> newBranches[0].rex
false -> rex(type = StaticType.unionOf(resultTypes.toSet()).flatten(), node.copy(branches = newBranches, default = default))
}
}
}

private fun canBeBoolean(type: StaticType): Boolean {
return type.flatten().allTypes.any {
it is BoolType
}
}

Expand Down

0 comments on commit a365a45

Please sign in to comment.