Skip to content

Commit

Permalink
Adds a DatumReader to conformance runner
Browse files Browse the repository at this point in the history
Fixes aggregations by specifying exact precision/scale

Adds comparison operators using new modeling

Adds in_collection function using new modeling

Rewrites times/division/mod operators using new modeling
  • Loading branch information
johnedquinn committed Nov 21, 2024
1 parent d56bc85 commit 8443f33
Show file tree
Hide file tree
Showing 27 changed files with 1,588 additions and 1,455 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import org.partiql.eval.internal.operator.rex.ExprCallDynamic.Candidate
import org.partiql.eval.internal.operator.rex.ExprCallDynamic.CoercionFamily.DYNAMIC
import org.partiql.eval.internal.operator.rex.ExprCallDynamic.CoercionFamily.UNKNOWN
import org.partiql.spi.function.Function
import org.partiql.spi.function.Parameter
import org.partiql.spi.value.Datum
import org.partiql.types.PType
import org.partiql.value.PartiQLValue
Expand Down Expand Up @@ -41,26 +40,6 @@ internal class ExprCallDynamic(
*/
private val paramIndices: IntRange = args.indices

/**
* @property paramTypes represents a two-dimensional array.
*
* 1. Dimension-1 maps to a candidate.
* 2. Dimension-2 maps to that candidate's parameter types.
*
* TODO actually make this an array instead of lists.
*/
private val paramTypes: List<List<Parameter>> = functions.map { c -> c.getParameters().toList() }

/**
* @property paramFamilies is a two-dimensional array.
*
* 1. Dimension-1 maps to the [candidates]
* 2. Dimension-2 maps to the [CoercionFamily].
*
* TODO actually make this an array instead of lists.
*/
private val paramFamilies: List<List<CoercionFamily>> = functions.map { c -> c.getParameters().map { p -> family(p.getType().kind) } }

/**
* A memoization cache for the [match] function.
*/
Expand Down Expand Up @@ -90,13 +69,14 @@ internal class ExprCallDynamic(
val argFamilies = args.map { family(it.kind) }
functions.indices.forEach { candidateIndex ->
var currentExactMatches = 0
val params = functions[candidateIndex].getInstance(args.toTypedArray())?.parameters ?: return@forEach
for (paramIndex in paramIndices) {
val argType = args[paramIndex]
val paramType = paramTypes[candidateIndex][paramIndex]
if (paramType.getMatch(argType) == argType) { currentExactMatches++ }
val paramType = params[paramIndex]
if (paramType.kind == argType.kind) { currentExactMatches++ } // TODO: Convert all functions to use the new modelling, or else we need to only check kinds
val argFamily = argFamilies[paramIndex]
val paramFamily = paramFamilies[candidateIndex][paramIndex]
if (paramFamily != argFamily && argFamily != CoercionFamily.UNKNOWN && paramFamily != CoercionFamily.DYNAMIC) { return@forEach }
val paramFamily = family(paramType.kind)
if (paramFamily != argFamily && argFamily != UNKNOWN && paramFamily != DYNAMIC) { return@forEach }
}
if (currentExactMatches > exactMatches) {
currentMatch = candidateIndex
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ internal object FnResolver {
for (i in args.indices) {
val a = args[i]
val p = parameters[i]
if (p != a) return false
// TODO: Don't use kind! Once all functions use the new modelling, we can just make it p != a.
if (p.kind != a.kind) return false
}
return true
}
Expand All @@ -153,7 +154,7 @@ internal object FnResolver {
// check match
val p = parameters[i]
when {
p == a -> exactInputTypes++
p.kind == a.kind -> exactInputTypes++ // TODO: Don't use kind! Once all functions use the new modelling, we can just make it p == a.
else -> mapping[i] = coercion(a, p) ?: return null
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ internal class PlanTransform(private val flags: Set<PlannerFlag>) {
// TODO assert on function name in plan typer .. here is not the place.
val args = node.args.map { visitRex(it, ctx) }
val fns = node.candidates.map { it.fn.signature }
return factory.rexCallDynamic("unknown", fns, args)
val name = node.candidates.first().fn.name.getName()
return factory.rexCallDynamic(name, fns, args)
}

override fun visitRexOpCallStatic(node: IRex.Op.Call.Static, ctx: PType): Any {
Expand Down
1 change: 1 addition & 0 deletions partiql-spi/api/partiql-spi.api
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ public final class org/partiql/spi/function/Parameter {
public final fun getType ()Lorg/partiql/types/PType;
public static final fun number (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter;
public static final fun text (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter;
public fun toString ()Ljava/lang/String;
}

public final class org/partiql/spi/function/Parameter$Companion {
Expand Down
2 changes: 1 addition & 1 deletion partiql-spi/src/main/java/org/partiql/spi/value/Datum.java
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ static Datum decimal(@NotNull BigDecimal value) {
static Datum decimal(@NotNull BigDecimal value, int precision, int scale) throws DataException {
BigDecimal d = value.round(new MathContext(precision)).setScale(scale, RoundingMode.HALF_UP);
if (d.precision() > precision) {
throw new DataException("Value " + d + " could not fit into decimal with precision/scale.");
throw new DataException("Value " + d + " could not fit into decimal with precision " + precision + " and scale " + scale + ".");
}
return new DatumDecimal(d, PType.decimal(precision, scale));
}
Expand Down
92 changes: 8 additions & 84 deletions partiql-spi/src/main/kotlin/org/partiql/spi/function/Builtins.kt
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,7 @@ internal object Builtins {
Fn_DATE_DIFF_SECOND__TIMESTAMP_TIMESTAMP__INT64,
Fn_DATE_DIFF_YEAR__DATE_DATE__INT64,
Fn_DATE_DIFF_YEAR__TIMESTAMP_TIMESTAMP__INT64,
Fn_DIVIDE__INT8_INT8__INT8,
Fn_DIVIDE__INT16_INT16__INT16,
Fn_DIVIDE__INT32_INT32__INT32,
Fn_DIVIDE__INT64_INT64__INT64,
Fn_DIVIDE__INT_INT__INT,
Fn_DIVIDE__FLOAT32_FLOAT32__FLOAT32,
Fn_DIVIDE__FLOAT64_FLOAT64__FLOAT64,
Fn_DIVIDE__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__DECIMAL_ARBITRARY,
FnDivide,
FnEq,
Fn_EXTRACT_DAY__DATE__INT32,
Fn_EXTRACT_DAY__TIMESTAMP__INT32,
Expand All @@ -148,38 +141,9 @@ internal object Builtins {
Fn_EXTRACT_TIMEZONE_MINUTE__TIMESTAMP__INT32,
Fn_EXTRACT_YEAR__DATE__INT32,
Fn_EXTRACT_YEAR__TIMESTAMP__INT32,
Fn_GT__BOOL_BOOL__BOOL,
Fn_GT__INT8_INT8__BOOL,
Fn_GT__INT16_INT16__BOOL,
Fn_GT__INT32_INT32__BOOL,
Fn_GT__INT64_INT64__BOOL,
Fn_GT__INT_INT__BOOL,
Fn_GT__FLOAT32_FLOAT32__BOOL,
Fn_GT__FLOAT64_FLOAT64__BOOL,
Fn_GT__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__BOOL,
Fn_GT__STRING_STRING__BOOL,

Fn_GT__DATE_DATE__BOOL,
Fn_GT__TIME_TIME__BOOL,
Fn_GT__TIMESTAMP_TIMESTAMP__BOOL,
Fn_GTE__BOOL_BOOL__BOOL,
Fn_GTE__INT8_INT8__BOOL,
Fn_GTE__INT16_INT16__BOOL,
Fn_GTE__INT32_INT32__BOOL,
Fn_GTE__INT64_INT64__BOOL,
Fn_GTE__INT_INT__BOOL,
Fn_GTE__FLOAT32_FLOAT32__BOOL,
Fn_GTE__FLOAT64_FLOAT64__BOOL,
Fn_GTE__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__BOOL,
Fn_GTE__STRING_STRING__BOOL,

Fn_GTE__DATE_DATE__BOOL,
Fn_GTE__TIME_TIME__BOOL,
Fn_GTE__TIMESTAMP_TIMESTAMP__BOOL,

Fn_IN_COLLECTION__ANY_BAG__BOOL,
Fn_IN_COLLECTION__ANY_LIST__BOOL,

FnGt,
FnGte,
FnInCollection,
Fn_IS_ANY__ANY__BOOL,
Fn_IS_BAG__ANY__BOOL,
Fn_IS_BINARY__ANY__BOOL,
Expand Down Expand Up @@ -221,43 +185,10 @@ internal object Builtins {
Fn_LOWER__STRING__STRING,
Fn_LOWER__CLOB__CLOB,

Fn_LT__BOOL_BOOL__BOOL,
Fn_LT__INT8_INT8__BOOL,
Fn_LT__INT16_INT16__BOOL,
Fn_LT__INT32_INT32__BOOL,
Fn_LT__INT64_INT64__BOOL,
Fn_LT__INT_INT__BOOL,
Fn_LT__FLOAT32_FLOAT32__BOOL,
Fn_LT__FLOAT64_FLOAT64__BOOL,
Fn_LT__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__BOOL,
Fn_LT__STRING_STRING__BOOL,

Fn_LT__DATE_DATE__BOOL,
Fn_LT__TIME_TIME__BOOL,
Fn_LT__TIMESTAMP_TIMESTAMP__BOOL,
Fn_LTE__BOOL_BOOL__BOOL,
Fn_LTE__INT8_INT8__BOOL,
Fn_LTE__INT16_INT16__BOOL,
Fn_LTE__INT32_INT32__BOOL,
Fn_LTE__INT64_INT64__BOOL,
Fn_LTE__INT_INT__BOOL,
Fn_LTE__FLOAT32_FLOAT32__BOOL,
Fn_LTE__FLOAT64_FLOAT64__BOOL,
Fn_LTE__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__BOOL,
Fn_LTE__STRING_STRING__BOOL,

Fn_LTE__DATE_DATE__BOOL,
Fn_LTE__TIME_TIME__BOOL,
Fn_LTE__TIMESTAMP_TIMESTAMP__BOOL,
FnLt,
FnLte,
FnMinus,
Fn_MODULO__INT8_INT8__INT8,
Fn_MODULO__INT16_INT16__INT16,
Fn_MODULO__INT32_INT32__INT32,
Fn_MODULO__INT64_INT64__INT64,
Fn_MODULO__INT_INT__INT,
Fn_MODULO__FLOAT32_FLOAT32__FLOAT32,
Fn_MODULO__FLOAT64_FLOAT64__FLOAT64,
Fn_MODULO__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__DECIMAL_ARBITRARY,
FnModulo,
Fn_NEG__INT8__INT8,
Fn_NEG__INT16__INT16,
Fn_NEG__INT32__INT32,
Expand Down Expand Up @@ -287,14 +218,7 @@ internal object Builtins {
Fn_SUBSTRING__CLOB_INT64__CLOB,
Fn_SUBSTRING__CLOB_INT64_INT64__CLOB,

Fn_TIMES__INT8_INT8__INT8,
Fn_TIMES__INT16_INT16__INT16,
Fn_TIMES__INT32_INT32__INT32,
Fn_TIMES__INT64_INT64__INT64,
Fn_TIMES__INT_INT__INT,
Fn_TIMES__FLOAT32_FLOAT32__FLOAT32,
Fn_TIMES__FLOAT64_FLOAT64__FLOAT64,
Fn_TIMES__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__DECIMAL_ARBITRARY,
FnTimes,
Fn_TRIM__STRING__STRING,
Fn_TRIM__CLOB__CLOB,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ public class Parameter private constructor(
*/
public fun getType(): PType = type.preferred

override fun toString(): String {
return name + ": " + type.preferred.toString()
}

/**
* Get match is used for function resolution; it indicates an exact match, coercion, or no match.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,22 @@ import org.partiql.types.PType

internal val Agg_AVG__INT8__INT8 = Aggregation.static(
name = "avg",
returns = PType.decimal(),
returns = PType.decimal(38, 19),
parameters = arrayOf(Parameter("value", PType.tinyint())),
accumulator = ::AccumulatorAvg,
)

internal val Agg_AVG__INT16__INT16 = Aggregation.static(
name = "avg",
returns = PType.decimal(),
returns = PType.decimal(38, 19),
parameters = arrayOf(Parameter("value", PType.smallint())),
accumulator = ::AccumulatorAvg,
)

internal val Agg_AVG__INT32__INT32 = Aggregation.static(

name = "avg",
returns = PType.decimal(),
returns = PType.decimal(38, 19),
parameters = arrayOf(
Parameter("value", PType.integer()),
),
Expand All @@ -35,7 +35,7 @@ internal val Agg_AVG__INT32__INT32 = Aggregation.static(
internal val Agg_AVG__INT64__INT64 = Aggregation.static(

name = "avg",
returns = PType.decimal(),
returns = PType.decimal(38, 19),
parameters = arrayOf(
Parameter("value", PType.bigint()),
),
Expand All @@ -45,7 +45,7 @@ internal val Agg_AVG__INT64__INT64 = Aggregation.static(
internal val Agg_AVG__INT__INT = Aggregation.static(

name = "avg",
returns = PType.decimal(),
returns = PType.decimal(38, 19),
parameters = arrayOf(
@Suppress("DEPRECATION") Parameter("value", PType.numeric()),
),
Expand All @@ -55,7 +55,7 @@ internal val Agg_AVG__INT__INT = Aggregation.static(
internal val Agg_AVG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Aggregation.static(

name = "avg",
returns = PType.decimal(),
returns = PType.decimal(38, 19), // TODO
parameters = arrayOf(
@Suppress("DEPRECATION") Parameter("value", PType.decimal()),
),
Expand Down Expand Up @@ -85,7 +85,7 @@ internal val Agg_AVG__FLOAT64__FLOAT64 = Aggregation.static(
internal val Agg_AVG__ANY__ANY = Aggregation.static(

name = "avg",
returns = PType.decimal(),
returns = PType.decimal(38, 19),
parameters = arrayOf(
Parameter("value", PType.dynamic()),
),
Expand Down
Loading

0 comments on commit 8443f33

Please sign in to comment.