Skip to content

Commit

Permalink
Adds basic aggregations to the partiql-planner (#1247)
Browse files Browse the repository at this point in the history
  • Loading branch information
RCHowell authored Oct 20, 2023
1 parent 5ca3723 commit 72a76cc
Show file tree
Hide file tree
Showing 19 changed files with 1,268 additions and 340 deletions.
14 changes: 14 additions & 0 deletions partiql-ast/src/main/kotlin/org/partiql/ast/normalize/AstPass.kt
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

package org.partiql.ast.normalize

import org.partiql.ast.Statement
Expand Down
15 changes: 15 additions & 0 deletions partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

package org.partiql.ast.normalize

import org.partiql.ast.Statement
Expand All @@ -10,5 +24,6 @@ public fun Statement.normalize(): Statement {
var ast = this
ast = NormalizeFromSource.apply(ast)
ast = NormalizeSelect.apply(ast)
ast = NormalizeGroupBy.apply(ast)
return ast
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

package org.partiql.ast.normalize

import org.partiql.ast.AstNode
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

package org.partiql.ast.normalize

import org.partiql.ast.Expr
import org.partiql.ast.GroupBy
import org.partiql.ast.Statement
import org.partiql.ast.groupByKey
import org.partiql.ast.helpers.toBinder
import org.partiql.ast.util.AstRewriter

/**
* Adds a unique binder to each group key.
*/
object NormalizeGroupBy : AstPass {

override fun apply(statement: Statement) = Visitor.visitStatement(statement, 0) as Statement

private object Visitor : AstRewriter<Int>() {

override fun visitGroupByKey(node: GroupBy.Key, ctx: Int): GroupBy.Key {
val expr = visitExpr(node.expr, 0) as Expr
val alias = when (node.asAlias) {
null -> expr.toBinder(ctx)
else -> node.asAlias
}
return if (expr !== node.expr || alias !== node.asAlias) {
groupByKey(expr, alias)
} else {
node
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

package org.partiql.ast.normalize

import org.partiql.ast.Expr
import org.partiql.ast.Select
import org.partiql.ast.Statement
import org.partiql.ast.builder.ast
import org.partiql.ast.helpers.toBinder
import org.partiql.ast.util.AstRewriter

/**
* Adds an `as` alias to every select-list item.
*
* - [org.partiql.ast.helpers.toBinder]
* - https://partiql.org/assets/PartiQL-Specification.pdf#page=28
* - https://web.cecs.pdx.edu/~len/sql1999.pdf#page=287
*/
internal object NormalizeSelectList : AstPass {

override fun apply(statement: Statement) = Visitor.visitStatement(statement, 0) as Statement

private object Visitor : AstRewriter<Int>() {

override fun visitSelectProject(node: Select.Project, ctx: Int) = ast {
if (node.items.isEmpty()) {
return@ast node
}
var diff = false
val transformed = ArrayList<Select.Project.Item>(node.items.size)
node.items.forEachIndexed { i, n ->
val item = visitSelectProjectItem(n, i) as Select.Project.Item
if (item !== n) diff = true
transformed.add(item)
}
// We don't want to create a new list unless we have to, as to not trigger further rewrites up the tree.
if (diff) selectProject(transformed) else node
}

override fun visitSelectProjectItemAll(node: Select.Project.Item.All, ctx: Int) = node.copy()

override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, ctx: Int) = ast {
val expr = visitExpr(node.expr, 0) as Expr
val alias = when (node.asAlias) {
null -> expr.toBinder(ctx)
else -> node.asAlias
}
if (expr != node.expr || alias != node.asAlias) {
selectProjectItemExpression(expr, alias)
} else {
node
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

package org.partiql.ast.normalize

import org.partiql.ast.Expr
Expand Down
20 changes: 15 additions & 5 deletions partiql-plan/src/main/resources/partiql_plan_0_1.ion
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ imports::{
kotlin: [
partiql_value::'org.partiql.value.PartiQLValue',
static_type::'org.partiql.types.StaticType',
function_signature::'org.partiql.types.function.FunctionSignature',
scalar_signature::'org.partiql.types.function.FunctionSignature$Scalar',
aggregation_signature::'org.partiql.types.function.FunctionSignature$Aggregation',
],
}

Expand All @@ -28,7 +29,16 @@ global::{

fn::[
resolved::{
signature: function_signature,
signature: scalar_signature,
},
unresolved::{
identifier: identifier,
},
]

agg::[
resolved::{
signature: aggregation_signature,
},
unresolved::{
identifier: identifier,
Expand Down Expand Up @@ -262,11 +272,11 @@ rel::{
aggregate::{
input: rel,
strategy: [ FULL, PARTIAL ],
aggs: list::[agg],
calls: list::[call],
groups: list::[rex],
_: [
agg::{
fn: fn,
call::{
agg: agg,
args: list::[rex],
},
],
Expand Down
45 changes: 33 additions & 12 deletions partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.partiql.planner

import org.partiql.plan.Agg
import org.partiql.plan.Fn
import org.partiql.plan.Global
import org.partiql.plan.Identifier
Expand All @@ -10,7 +11,6 @@ import org.partiql.plan.identifierQualified
import org.partiql.plan.identifierSymbol
import org.partiql.planner.typer.FunctionResolver
import org.partiql.planner.typer.Mapping
import org.partiql.planner.typer.isNullOrMissing
import org.partiql.planner.typer.toRuntimeType
import org.partiql.spi.BindingCase
import org.partiql.spi.BindingName
Expand Down Expand Up @@ -73,7 +73,7 @@ internal class TypeEnv(
/**
* Result of attempting to match an unresolved function.
*/
internal sealed class FnMatch {
internal sealed class FnMatch<T : FunctionSignature> {

/**
* 7.1 Inputs with wrong types
Expand All @@ -83,17 +83,17 @@ internal sealed class FnMatch {
* @property mapping
* @property isMissable TRUE when anyone of the arguments _could_ be MISSING. We *always* propagate MISSING.
*/
public data class Ok(
public val signature: FunctionSignature,
public data class Ok<T : FunctionSignature>(
public val signature: T,
public val mapping: Mapping,
public val isMissable: Boolean,
) : FnMatch()
) : FnMatch<T>()

public data class Error(
public val fn: Fn.Unresolved,
public data class Error<T : FunctionSignature>(
public val identifier: Identifier,
public val args: List<Rex>,
public val candidates: List<FunctionSignature>,
) : FnMatch()
) : FnMatch<T>()
}

/**
Expand Down Expand Up @@ -195,21 +195,42 @@ internal class Env(
}

/**
* Leverages a [FunctionResolver] to find a matching function defined in the [Header].
* Leverages a [FunctionResolver] to find a matching function defined in the [Header] scalar function catalog.
*/
internal fun resolveFn(fn: Fn.Unresolved, args: List<Rex>): FnMatch {
internal fun resolveFn(fn: Fn.Unresolved, args: List<Rex>): FnMatch<FunctionSignature.Scalar> {
val candidates = header.lookup(fn)
var hadMissingArg = false
val parameters = args.mapIndexed { i, arg ->
if (!hadMissingArg && arg.type.isMissable()) {
hadMissingArg = true
}
arg.type.isNullOrMissing()
FunctionParameter("arg-$i", arg.type.toRuntimeType())
}
val match = functionResolver.match(candidates, parameters)
return when (match) {
null -> FnMatch.Error(fn, args, candidates)
null -> FnMatch.Error(fn.identifier, args, candidates)
else -> {
val isMissable = hadMissingArg || header.isUnsafeCast(match.signature.specific)
FnMatch.Ok(match.signature, match.mapping, isMissable)
}
}
}

/**
* Leverages a [FunctionResolver] to find a matching function defined in the [Header] aggregation function catalog.
*/
internal fun resolveAgg(agg: Agg.Unresolved, args: List<Rex>): FnMatch<FunctionSignature.Aggregation> {
val candidates = header.lookup(agg)
var hadMissingArg = false
val parameters = args.mapIndexed { i, arg ->
if (!hadMissingArg && arg.type.isMissable()) {
hadMissingArg = true
}
FunctionParameter("arg-$i", arg.type.toRuntimeType())
}
val match = functionResolver.match(candidates, parameters)
return when (match) {
null -> FnMatch.Error(agg.identifier, args, candidates)
else -> {
val isMissable = hadMissingArg || header.isUnsafeCast(match.signature.specific)
FnMatch.Ok(match.signature, match.mapping, isMissable)
Expand Down
Loading

0 comments on commit 72a76cc

Please sign in to comment.