diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/AggAvg.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/AggAvg.kt index b5bc61e2e..3dcfd2143 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/AggAvg.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/AggAvg.kt @@ -8,16 +8,18 @@ import org.partiql.spi.function.Parameter import org.partiql.spi.function.builtins.internal.AccumulatorAvg import org.partiql.types.PType +private val AVG_DECIMAL = PType.decimal(38, 19) + internal val Agg_AVG__INT8__INT8 = Aggregation.static( name = "avg", - returns = PType.decimal(38, 19), + returns = AVG_DECIMAL, parameters = arrayOf(Parameter("value", PType.tinyint())), accumulator = ::AccumulatorAvg, ) internal val Agg_AVG__INT16__INT16 = Aggregation.static( name = "avg", - returns = PType.decimal(38, 19), + returns = AVG_DECIMAL, parameters = arrayOf(Parameter("value", PType.smallint())), accumulator = ::AccumulatorAvg, ) @@ -25,7 +27,7 @@ internal val Agg_AVG__INT16__INT16 = Aggregation.static( internal val Agg_AVG__INT32__INT32 = Aggregation.static( name = "avg", - returns = PType.decimal(38, 19), + returns = AVG_DECIMAL, parameters = arrayOf( Parameter("value", PType.integer()), ), @@ -35,7 +37,7 @@ internal val Agg_AVG__INT32__INT32 = Aggregation.static( internal val Agg_AVG__INT64__INT64 = Aggregation.static( name = "avg", - returns = PType.decimal(38, 19), + returns = AVG_DECIMAL, parameters = arrayOf( Parameter("value", PType.bigint()), ), @@ -45,7 +47,7 @@ internal val Agg_AVG__INT64__INT64 = Aggregation.static( internal val Agg_AVG__INT__INT = Aggregation.static( name = "avg", - returns = PType.decimal(38, 19), + returns = AVG_DECIMAL, parameters = arrayOf( @Suppress("DEPRECATION") Parameter("value", PType.numeric()), ), @@ -55,9 +57,9 @@ internal val Agg_AVG__INT__INT = Aggregation.static( internal val Agg_AVG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Aggregation.static( name = "avg", - returns = PType.decimal(38, 19), + returns = AVG_DECIMAL, parameters = arrayOf( - @Suppress("DEPRECATION") Parameter("value", PType.decimal()), + Parameter("value", AVG_DECIMAL), ), accumulator = ::AccumulatorAvg, ) @@ -85,7 +87,7 @@ internal val Agg_AVG__FLOAT64__FLOAT64 = Aggregation.static( internal val Agg_AVG__ANY__ANY = Aggregation.static( name = "avg", - returns = PType.decimal(38, 19), + returns = PType.dynamic(), parameters = arrayOf( Parameter("value", PType.dynamic()), ), diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnModulo.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnModulo.kt index e16cef6cc..772390887 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnModulo.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnModulo.kt @@ -53,9 +53,14 @@ internal object FnModulo : DiadicArithmeticOperator("modulo") { } } + /** + * SQL Server + * p = min(p1 - s1, p2 - s2) + max(s1, s2) + * s = max(s1, s2) + */ override fun getDecimalInstance(decimalLhs: PType, decimalRhs: PType): Function.Instance { - val p = decimalLhs.precision - decimalLhs.scale + decimalRhs.scale + Math.max(6, decimalLhs.scale + decimalRhs.precision + 1) - val s = Math.max(6, decimalLhs.scale + decimalRhs.precision + 1) + val p = Math.min(decimalLhs.precision - decimalLhs.scale, decimalRhs.precision - decimalRhs.scale) + Math.max(decimalLhs.scale, decimalRhs.scale) + val s = Math.max(decimalLhs.scale, decimalRhs.scale) return basic(PType.decimal()) { args -> val arg0 = args[0].bigDecimal val arg1 = args[1].bigDecimal