Skip to content

Commit

Permalink
Updates Modulo's precision/scale calculation
Browse files Browse the repository at this point in the history
Adds a shared static decimal per PR feedback
  • Loading branch information
johnedquinn committed Nov 26, 2024
1 parent ecbe859 commit 473a5e0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,26 @@ import org.partiql.spi.function.Parameter
import org.partiql.spi.function.builtins.internal.AccumulatorAvg
import org.partiql.types.PType

private val AVG_DECIMAL = PType.decimal(38, 19)

internal val Agg_AVG__INT8__INT8 = Aggregation.static(
name = "avg",
returns = PType.decimal(38, 19),
returns = AVG_DECIMAL,
parameters = arrayOf(Parameter("value", PType.tinyint())),
accumulator = ::AccumulatorAvg,
)

internal val Agg_AVG__INT16__INT16 = Aggregation.static(
name = "avg",
returns = PType.decimal(38, 19),
returns = AVG_DECIMAL,
parameters = arrayOf(Parameter("value", PType.smallint())),
accumulator = ::AccumulatorAvg,
)

internal val Agg_AVG__INT32__INT32 = Aggregation.static(

name = "avg",
returns = PType.decimal(38, 19),
returns = AVG_DECIMAL,
parameters = arrayOf(
Parameter("value", PType.integer()),
),
Expand All @@ -35,7 +37,7 @@ internal val Agg_AVG__INT32__INT32 = Aggregation.static(
internal val Agg_AVG__INT64__INT64 = Aggregation.static(

name = "avg",
returns = PType.decimal(38, 19),
returns = AVG_DECIMAL,
parameters = arrayOf(
Parameter("value", PType.bigint()),
),
Expand All @@ -45,7 +47,7 @@ internal val Agg_AVG__INT64__INT64 = Aggregation.static(
internal val Agg_AVG__INT__INT = Aggregation.static(

name = "avg",
returns = PType.decimal(38, 19),
returns = AVG_DECIMAL,
parameters = arrayOf(
@Suppress("DEPRECATION") Parameter("value", PType.numeric()),
),
Expand All @@ -55,9 +57,9 @@ internal val Agg_AVG__INT__INT = Aggregation.static(
internal val Agg_AVG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Aggregation.static(

name = "avg",
returns = PType.decimal(38, 19),
returns = AVG_DECIMAL,
parameters = arrayOf(
@Suppress("DEPRECATION") Parameter("value", PType.decimal()),
Parameter("value", AVG_DECIMAL),
),
accumulator = ::AccumulatorAvg,
)
Expand Down Expand Up @@ -85,7 +87,7 @@ internal val Agg_AVG__FLOAT64__FLOAT64 = Aggregation.static(
internal val Agg_AVG__ANY__ANY = Aggregation.static(

name = "avg",
returns = PType.decimal(38, 19),
returns = PType.dynamic(),
parameters = arrayOf(
Parameter("value", PType.dynamic()),
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,14 @@ internal object FnModulo : DiadicArithmeticOperator("modulo") {
}
}

/**
* SQL Server
* p = min(p1 - s1, p2 - s2) + max(s1, s2)
* s = max(s1, s2)
*/
override fun getDecimalInstance(decimalLhs: PType, decimalRhs: PType): Function.Instance {
val p = decimalLhs.precision - decimalLhs.scale + decimalRhs.scale + Math.max(6, decimalLhs.scale + decimalRhs.precision + 1)
val s = Math.max(6, decimalLhs.scale + decimalRhs.precision + 1)
val p = Math.min(decimalLhs.precision - decimalLhs.scale, decimalRhs.precision - decimalRhs.scale) + Math.max(decimalLhs.scale, decimalRhs.scale)
val s = Math.max(decimalLhs.scale, decimalRhs.scale)
return basic(PType.decimal()) { args ->
val arg0 = args[0].bigDecimal
val arg1 = args[1].bigDecimal
Expand Down

0 comments on commit 473a5e0

Please sign in to comment.