Skip to content

Commit

Permalink
Fixes aggregations of attribute references to values of union types
Browse files Browse the repository at this point in the history
  • Loading branch information
johnedquinn committed Mar 6, 2024
1 parent 5121093 commit 9084f54
Show file tree
Hide file tree
Showing 9 changed files with 248 additions and 37 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,20 @@ Thank you to all who have contributed!

### Changed
- Function resolution logic: Now the function resolver would match all possible candidate(based on if the argument can be coerced to the Signature parameter type). If there are multiple match it will first attempt to pick the one requires the least cast, then pick the function with the highest precedence.
- **Behavioral change**: The COUNT aggregate function now returns INT64.

### Deprecated

### Fixed
- Fixes aggregations of attribute references to values of union types. This fix also allows for proper error handling by passing the UnknownAggregateFunction problem to the ProblemCallback. Please note that, with this change, the planner will no longer immediately throw an IllegalStateException for this exact scenario.

### Removed

### Security

### Contributors
Thank you to all who have contributed!
- @<your-username>
- @johnedquinn

## [0.14.3] - 2024-02-14

Expand Down
8 changes: 8 additions & 0 deletions partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ public sealed class PlanningProblemDetails(
"Unknown function `$identifier($types)"
})

public data class UnknownAggregateFunction(
val identifier: String,
val args: List<StaticType>,
) : PlanningProblemDetails(ProblemSeverity.ERROR, {
val types = args.joinToString { "<${it.toString().lowercase()}>" }
"Unknown aggregate function `$identifier($types)"
})

public object ExpressionAlwaysReturnsNullOrMissing : PlanningProblemDetails(
severity = ProblemSeverity.ERROR,
messageFormatter = { "Expression always returns null or missing." }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -702,13 +702,13 @@ internal object PartiQLHeader : Header() {
private fun count() = listOf(
FunctionSignature.Aggregation(
name = "count",
returns = INT32,
returns = INT64,
parameters = listOf(FunctionParameter("value", ANY)),
isNullable = false,
),
FunctionSignature.Aggregation(
name = "count_star",
returns = INT32,
returns = INT64,
parameters = listOf(),
isNullable = false,
),
Expand Down Expand Up @@ -741,6 +741,15 @@ internal object PartiQLHeader : Header() {
)
}

/**
* According to SQL:1999 Section 6.16 Syntax Rule 14.c and Rule 14.d:
* > If AVG is specified and DT is exact numeric, then the declared type of the result is exact
* numeric with implementation-defined precision not less than the precision of DT and
* implementation-defined scale not less than the scale of DT.
*
* > If DT is approximate numeric, then the declared type of the result is approximate numeric
* with implementation-defined precision not less than the precision of DT.
*/
private fun avg() = types.numeric.map {
FunctionSignature.Aggregation(
name = "avg",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package org.partiql.planner.internal.transforms

import org.partiql.errors.Problem
import org.partiql.errors.ProblemCallback
import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION
import org.partiql.plan.PlanNode
import org.partiql.plan.partiQLPlan
import org.partiql.planner.PlanningProblemDetails
import org.partiql.planner.internal.ir.Agg
import org.partiql.planner.internal.ir.Catalog
import org.partiql.planner.internal.ir.Fn
Expand All @@ -12,7 +15,9 @@ import org.partiql.planner.internal.ir.Rel
import org.partiql.planner.internal.ir.Rex
import org.partiql.planner.internal.ir.Statement
import org.partiql.planner.internal.ir.visitor.PlanBaseVisitor
import org.partiql.types.function.FunctionSignature
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType

/**
* This is an internal utility to translate from the internal unresolved plan used for typing to the public plan IR.
Expand Down Expand Up @@ -58,7 +63,7 @@ internal object PlanTransform : PlanBaseVisitor<PlanNode, ProblemCallback>() {
override fun visitAggResolved(node: Agg.Resolved, ctx: ProblemCallback) = org.partiql.plan.Agg(node.signature)

override fun visitAggUnresolved(node: Agg.Unresolved, ctx: ProblemCallback): org.partiql.plan.Rex.Op {
error("Unresolved aggregation ${node.identifier}")
error("Internal error: This should have been handled somewhere else. Cause: Unresolved aggregation ${node.identifier}.")
}

override fun visitStatement(node: Statement, ctx: ProblemCallback) =
Expand Down Expand Up @@ -331,11 +336,56 @@ internal object PlanTransform : PlanBaseVisitor<PlanNode, ProblemCallback>() {
groups = node.groups.map { visitRex(it, ctx) },
)

override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: ProblemCallback) =
org.partiql.plan.Rel.Op.Aggregate.Call(
agg = visitAgg(node.agg, ctx),
@OptIn(PartiQLValueExperimental::class)
override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: ProblemCallback): org.partiql.plan.Rel.Op.Aggregate.Call {
val agg = when (val agg = node.agg) {
is Agg.Unresolved -> {
val name = agg.identifier.toNormalizedString()
ctx.invoke(
Problem(
UNKNOWN_PROBLEM_LOCATION,
PlanningProblemDetails.UnknownAggregateFunction(
agg.identifier.toString(),
node.args.map { it.type }
)
)
)
org.partiql.plan.Agg(
FunctionSignature.Aggregation(
"UNKNOWN_AGG::$name",
returns = PartiQLValueType.MISSING,
parameters = emptyList()
)
)
}
is Agg.Resolved -> {
visitAggResolved(agg, ctx)
}
}
return org.partiql.plan.Rel.Op.Aggregate.Call(
agg = agg,
args = node.args.map { visitRex(it, ctx) },
)
}

private fun Identifier.toNormalizedString(): String {
return when (this) {
is Identifier.Symbol -> this.toNormalizedString()
is Identifier.Qualified -> {
val toJoin = listOf(this.root) + this.steps
toJoin.joinToString(separator = ".") { ident ->
ident.toNormalizedString()
}
}
}
}

private fun Identifier.Symbol.toNormalizedString(): String {
return when (this.caseSensitivity) {
Identifier.CaseSensitivity.SENSITIVE -> "\"${this.symbol}\""
Identifier.CaseSensitivity.INSENSITIVE -> this.symbol
}
}

override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: ProblemCallback) = org.partiql.plan.Rel.Op.Exclude(
input = visitRel(node.input, ctx),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ internal class PlanTyper(
fun resolveAgg(agg: Agg.Unresolved, arguments: List<Rex>): Pair<Rel.Op.Aggregate.Call, StaticType> {
var missingArg = false
val args = arguments.map {
val arg = visitRex(it, null)
val arg = visitRex(it, it.type)
if (arg.type.isMissable()) missingArg = true
arg
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import org.partiql.planner.internal.ir.rexOpVarResolved
import org.partiql.spi.BindingCase
import org.partiql.spi.BindingName
import org.partiql.spi.BindingPath
import org.partiql.types.AnyOfType
import org.partiql.types.AnyType
import org.partiql.types.StaticType
import org.partiql.types.StructType
import org.partiql.types.TupleConstraint
Expand Down Expand Up @@ -85,30 +87,28 @@ internal class TypeEnv(public val schema: List<Rel.Binding>) {
for (i in schema.indices) {
val local = schema[i]
val type = local.type
if (type is StructType) {
when (type.containsKey(name)) {
true -> {
if (c != null && known) {
// TODO root was already definitively matched, emit ambiguous error.
return null
}
c = rex(type, rexOpVarResolved(i))
known = true
when (type.containsKey(name)) {
true -> {
if (c != null && known) {
// TODO root was already definitively matched, emit ambiguous error.
return null
}
null -> {
if (c != null) {
if (known) {
continue
} else {
// TODO we have more than one possible match, emit ambiguous error.
return null
}
c = rex(type, rexOpVarResolved(i))
known = true
}
null -> {
if (c != null) {
if (known) {
continue
} else {
// TODO we have more than one possible match, emit ambiguous error.
return null
}
c = rex(type, rexOpVarResolved(i))
known = false
}
false -> continue
c = rex(type, rexOpVarResolved(i))
known = false
}
false -> continue
}
}
return c
Expand Down Expand Up @@ -152,4 +152,39 @@ internal class TypeEnv(public val schema: List<Rel.Binding>) {
val closed = constraints.contains(TupleConstraint.Open(false))
return if (closed) false else null
}

/**
* Searches for the [BindingName] within the given [StaticType].
*
* Returns
* - true iff known to contain key
* - false iff known to NOT contain key
* - null iff NOT known to contain key
*
* @param name
* @return
*/
private fun StaticType.containsKey(name: BindingName): Boolean? {
return when (val type = this.flatten()) {
is StructType -> type.containsKey(name)
is AnyOfType -> {
val anyKnownToContainKey = type.allTypes.any { it.containsKey(name) == true }
val anyKnownToNotContainKey = type.allTypes.any { it.containsKey(name) == false }
val anyNotKnownToContainKey = type.allTypes.any { it.containsKey(name) == null }
when {
// There are:
// - No subtypes that are known to not contain the key
// - No subtypes that are not known to contain the key
anyKnownToNotContainKey.not() && anyNotKnownToContainKey.not() -> true
// There are:
// - No subtypes that are known to contain the key
// - No subtypes that are not known to contain the key
anyKnownToContainKey.not() && anyNotKnownToContainKey.not() -> false
else -> null
}
}
is AnyType -> null
else -> false
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal fun StaticType.toRuntimeType(): PartiQLValueType {
// handle anyOf(null, T) cases
val t = types.filter { it !is NullType && it !is MissingType }
return if (t.size != 1) {
error("Cannot have a UNION runtime type: $this")
PartiQLValueType.ANY
} else {
t.first().asRuntimeType()
}
Expand Down
Loading

0 comments on commit 9084f54

Please sign in to comment.