Skip to content

Commit

Permalink
Allows providing headers to PartiQLPlanner
Browse files Browse the repository at this point in the history
  • Loading branch information
rchowell committed Oct 27, 2023
1 parent 7d9cf21 commit 94c02ef
Show file tree
Hide file tree
Showing 12 changed files with 1,235 additions and 1,121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ class PartiQLSchemaInferencerTests {
@Execution(ExecutionMode.CONCURRENT)
fun testTupleUnion(tc: TestCase) = runTest(tc)

@ParameterizedTest
@MethodSource("aggregationCases")
@Execution(ExecutionMode.CONCURRENT)
fun testAggregations(tc: TestCase) = runTest(tc)

companion object {

private val root = this::class.java.getResource("/catalogs/default")!!.toURI().toPath().pathString
Expand Down Expand Up @@ -2123,6 +2128,50 @@ class PartiQLSchemaInferencerTests {
),
),
)

@JvmStatic
fun aggregationCases() = listOf(
SuccessTestCase(
name = "AGGREGATE over INTS",
query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a",
expected = BagType(
StructType(
fields = mapOf(
"a" to INT,
"c" to INT4,
"s" to INT.asNullable(),
"m" to INT.asNullable(),
),
contentClosed = true,
constraints = setOf(
TupleConstraint.Open(false),
TupleConstraint.UniqueAttrs(true),
TupleConstraint.Ordered
)
)
)
),
SuccessTestCase(
name = "AGGREGATE over DECIMALS",
query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1.0, 'b': 2.0}, {'a': 1.0, 'b': 2.0} >> GROUP BY a",
expected = BagType(
StructType(
fields = mapOf(
"a" to StaticType.DECIMAL,
"c" to INT4,
"s" to StaticType.DECIMAL.asNullable(),
"m" to StaticType.DECIMAL.asNullable(),
),
contentClosed = true,
constraints = setOf(
TupleConstraint.Open(false),
TupleConstraint.UniqueAttrs(true),
TupleConstraint.Ordered
)
)
)
),
)
}

sealed class TestCase {
Expand Down Expand Up @@ -2790,46 +2839,6 @@ class PartiQLSchemaInferencerTests {
)
)
),
SuccessTestCase(
name = "AGGREGATE over INTS",
query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a",
expected = BagType(
StructType(
fields = mapOf(
"a" to INT,
"c" to INT,
"s" to INT,
"m" to INT,
),
contentClosed = true,
constraints = setOf(
TupleConstraint.Open(false),
TupleConstraint.UniqueAttrs(true),
TupleConstraint.Ordered
)
)
)
),
SuccessTestCase(
name = "AGGREGATE over DECIMALS",
query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1.0, 'b': 2.0}, {'a': 1.0, 'b': 2.0} >> GROUP BY a",
expected = BagType(
StructType(
fields = mapOf(
"a" to StaticType.DECIMAL,
"c" to INT,
"s" to StaticType.DECIMAL,
"m" to StaticType.DECIMAL,
),
contentClosed = true,
constraints = setOf(
TupleConstraint.Open(false),
TupleConstraint.UniqueAttrs(true),
TupleConstraint.Ordered
)
)
)
),
SuccessTestCase(
name = "Current User",
query = "CURRENT_USER",
Expand Down
85 changes: 7 additions & 78 deletions partiql-planner/src/main/kotlin/org/partiql/planner/Env.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ import org.partiql.plan.Rex
import org.partiql.plan.global
import org.partiql.plan.identifierQualified
import org.partiql.plan.identifierSymbol
import org.partiql.planner.typer.FunctionResolver
import org.partiql.planner.typer.Mapping
import org.partiql.planner.typer.toRuntimeType
import org.partiql.planner.typer.FnResolver
import org.partiql.spi.BindingCase
import org.partiql.spi.BindingName
import org.partiql.spi.BindingPath
Expand All @@ -25,9 +23,6 @@ import org.partiql.spi.connector.Constants
import org.partiql.types.StaticType
import org.partiql.types.StructType
import org.partiql.types.TupleConstraint
import org.partiql.types.function.FunctionParameter
import org.partiql.types.function.FunctionSignature
import org.partiql.value.PartiQLValueExperimental

/**
* Handle for associating a catalog with the metadata; pair of catalog to data.
Expand Down Expand Up @@ -70,32 +65,6 @@ internal class TypeEnv(
}
}

/**
* Result of attempting to match an unresolved function.
*/
internal sealed class FnMatch<T : FunctionSignature> {

/**
* 7.1 Inputs with wrong types
* It follows that all functions return MISSING when one of their inputs is MISSING
*
* @property signature
* @property mapping
* @property isMissable TRUE when anyone of the arguments _could_ be MISSING. We *always* propagate MISSING.
*/
public data class Ok<T : FunctionSignature>(
public val signature: T,
public val mapping: Mapping,
public val isMissable: Boolean,
) : FnMatch<T>()

public data class Error<T : FunctionSignature>(
public val identifier: Identifier,
public val args: List<Rex>,
public val candidates: List<FunctionSignature>,
) : FnMatch<T>()
}

/**
* Metadata regarding a resolved variable.
*/
Expand Down Expand Up @@ -148,13 +117,12 @@ internal enum class ResolutionStrategy {
/**
* PartiQL Planner Global Environment of Catalogs backed by given plugins.
*
* @property header List of namespaced definitions
* @property headers List of namespaced definitions
* @property plugins List of plugins for global resolution
* @property session Session details
*/
@OptIn(PartiQLValueExperimental::class)
internal class Env(
private val header: Header,
private val headers: List<Header>,
private val plugins: List<Plugin>,
private val session: PartiQLPlanner.Session,
) {
Expand All @@ -165,9 +133,9 @@ internal class Env(
public val globals = mutableListOf<Global>()

/**
* Encapsulate function matching logic in
* Encapsulate all function resolving logic within [FnResolver].
*/
public val functionResolver = FunctionResolver(header)
public val fnResolver = FnResolver(headers)

private val connectorSession = object : ConnectorSession {
override fun getQueryId(): String = session.queryId
Expand Down Expand Up @@ -197,51 +165,12 @@ internal class Env(
/**
* Leverages a [FunctionResolver] to find a matching function defined in the [Header] scalar function catalog.
*/
internal fun resolveFn(fn: Fn.Unresolved, args: List<Rex>): FnMatch<FunctionSignature.Scalar> {
val candidates = header.lookup(fn)
var hadMissingArg = false
val parameters = args.mapIndexed { i, arg ->
if (!hadMissingArg && arg.type.isMissable()) {
hadMissingArg = true
}
FunctionParameter("arg-$i", arg.type.toRuntimeType())
}
val match = functionResolver.match(candidates, parameters)
return when (match) {
null -> FnMatch.Error(fn.identifier, args, candidates)
else -> {
val isMissable = hadMissingArg || header.isUnsafeCast(match.signature.specific)
FnMatch.Ok(match.signature, match.mapping, isMissable)
}
}
}
internal fun resolveFn(fn: Fn.Unresolved, args: List<Rex>) = fnResolver.resolveFn(fn, args)

/**
* Leverages a [FunctionResolver] to find a matching function defined in the [Header] aggregation function catalog.
*/
internal fun resolveAgg(agg: Agg.Unresolved, args: List<Rex>): FnMatch<FunctionSignature.Aggregation> {
val candidates = header.lookup(agg)
var hadMissingArg = false
val parameters = args.mapIndexed { i, arg ->
if (!hadMissingArg && arg.type.isMissable()) {
hadMissingArg = true
}
FunctionParameter("arg-$i", arg.type.toRuntimeType())
}
val match = functionResolver.match(candidates, parameters)
return when (match) {
null -> FnMatch.Error(agg.identifier, args, candidates)
else -> {
val isMissable = hadMissingArg || header.isUnsafeCast(match.signature.specific)
FnMatch.Ok(match.signature, match.mapping, isMissable)
}
}
}

/**
* TODO
*/
internal fun getFnAggHandle(identifier: Identifier): Nothing = TODO()
internal fun resolveAgg(agg: Agg.Unresolved, args: List<Rex>) = fnResolver.resolveAgg(agg, args)

/**
* Fetch global object metadata from the given [BindingPath].
Expand Down
Loading

0 comments on commit 94c02ef

Please sign in to comment.