feat: add MakeDecimal support to spark module
The Spark query optimiser injects an internal function (MakeDecimal) when numeric literals appear in a query.
This commit adds support for that function, which drastically improves the pass rate of the TPC-DS test suite.

Signed-off-by: Andrew Coleman <[email protected]>
andrew-coleman committed Sep 30, 2024
1 parent f22b3d0 commit 0346df2
Showing 4 changed files with 49 additions and 21 deletions.
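
Background for reviewers: MakeDecimal is a Spark-internal expression that reconstitutes a Decimal from its unscaled Long representation. One place the optimizer is known to inject it is the DecimalAggregates rule, which rewrites aggregates over small-precision decimals to run on unscaled Longs. A minimal sketch (the table and column names are made up for illustration):

import org.apache.spark.sql.SparkSession

// Build a tiny decimal-typed view; the names here are hypothetical.
val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.range(10).selectExpr("CAST(id AS DECIMAL(7,2)) AS price").createOrReplaceTempView("items")

// The optimized plan contains roughly
//   MakeDecimal(sum(UnscaledValue(price)), 17, 2)
// i.e. the sum runs on Longs and is converted back to a decimal at the end.
println(spark.sql("SELECT sum(price) FROM items").queryExecution.optimizedPlan)
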
9 changes: 9 additions & 0 deletions spark/src/main/resources/spark.yml
@@ -32,3 +32,12 @@ scalar_functions:
       - args:
           - value: DECIMAL&lt;P,S&gt;
         return: i64
+  -
+    name: make_decimal
+    description: >-
+      Return the Decimal value of an unscaled Long.
+      Note: this expression is internal and created only by the optimizer,
+    impls:
+      - args:
+          - value: i64
+        return: DECIMAL&lt;P,S&gt;
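
The declared signature mirrors Spark's MakeDecimal expression: the single i64 argument is the unscaled value, while the precision P and scale S are carried by the DECIMAL&lt;P,S&gt; output type. A small sketch of the semantics, assuming Spark 3.x's four-parameter MakeDecimal constructor:

import org.apache.spark.sql.catalyst.expressions.{Literal, MakeDecimal}

// Unscaled long 12345 read back as DECIMAL(5,2): 123.45
val expr = MakeDecimal(Literal(12345L), precision = 5, scale = 2, nullOnOverflow = true)
println(expr.eval()) // prints 123.45
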
1 change: 1 addition &amp; 0 deletions spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala
@@ -58,6 +58,7 @@ class FunctionMappings {
     s[Year]("year"),
 
     // internal
+    s[MakeDecimal]("make_decimal"),
     s[UnscaledValue]("unscaled")
   )
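
For context, a rough sketch of what the s[T](name) helper is assumed to do — this is a simplification, not the actual FunctionMappings source: it pairs a Spark expression class with a Substrait function name so the converters can look the mapping up in either direction.

import scala.reflect.ClassTag
import org.apache.spark.sql.catalyst.expressions.Expression

// Hypothetical simplification of a mapping entry; the real Sig class in
// this module also knows how to construct the Spark expression (makeCall).
case class Sig(expClass: Class[_], name: String)
def s[T <: Expression: ClassTag](name: String): Sig =
  Sig(implicitly[ClassTag[T]].runtimeClass, name)
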

49 changes: 29 additions &amp; 20 deletions spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -18,8 +18,8 @@ package io.substrait.spark.expression

 import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, ToSubstraitType}
 import io.substrait.spark.logical.ToLogicalPlan
-import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, NamedExpression, ScalarSubquery}
-import org.apache.spark.sql.types.{Decimal, NullType}
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
+import org.apache.spark.sql.types.Decimal
 import org.apache.spark.unsafe.types.UTF8String
 import io.substrait.`type`.{StringTypeVisitor, Type}
 import io.substrait.{expression => exp}
@@ -131,23 +131,32 @@ class ToSparkExpression(
         arg.accept(expr.declaration(), i, this)
     }

-    scalarFunctionConverter
-      .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
-      .flatMap(sig => Option(sig.makeCall(args)))
-      .getOrElse({
-        val msg = String.format(
-          "Unable to convert scalar function %s(%s).",
-          expr.declaration.name,
-          expr.arguments.asScala
-            .map {
-              case ea: exp.EnumArg => ea.value.toString
-              case e: SExpression => e.getType.accept(new StringTypeVisitor)
-              case t: Type => t.accept(new StringTypeVisitor)
-              case a => throw new IllegalStateException("Unexpected value: " + a)
-            }
-            .mkString(", ")
-        )
-        throw new IllegalArgumentException(msg)
-      })
+    expr.declaration.name match {
+      case "make_decimal" => expr.outputType match {
+          // Need special-case handling of this internal function (not nice, I know).
+          // Because the precision and scale arguments are extracted from the output type,
+          // we can't use the generic scalar function conversion mechanism here.
+          case d: Type.Decimal => MakeDecimal(args.head, d.precision, d.scale)
+          case _ => throw new IllegalArgumentException("Output type of MakeDecimal must be a decimal type")
+        }
+      case _ => scalarFunctionConverter
+          .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
+          .flatMap(sig => Option(sig.makeCall(args)))
+          .getOrElse({
+            val msg = String.format(
+              "Unable to convert scalar function %s(%s).",
+              expr.declaration.name,
+              expr.arguments.asScala
+                .map {
+                  case ea: exp.EnumArg => ea.value.toString
+                  case e: SExpression => e.getType.accept(new StringTypeVisitor)
+                  case t: Type => t.accept(new StringTypeVisitor)
+                  case a => throw new IllegalStateException("Unexpected value: " + a)
+                }
+                .mkString(", ")
+            )
+            throw new IllegalArgumentException(msg)
+          })
+    }
   }
 }
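
The special case exists because of an asymmetry between the two representations: Substrait's make_decimal carries precision and scale only in its DECIMAL&lt;P,S&gt; return type, while Spark's MakeDecimal expression takes them as constructor parameters alongside the child. The converter therefore has to read P and S from expr.outputType; the generic path, which builds the Spark call from the argument list alone, has nowhere to get them from.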
11 changes: 10 additions & 1 deletion spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
@@ -32,7 +32,16 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
   }
 
   // "q9" failed in spark 3.3
-  val successfulSQL: Set[String] = Set("q4", "q7", "q18", "q22", "q26", "q28", "q29", "q37", "q41", "q48", "q50", "q62", "q69", "q82", "q85", "q88", "q90", "q93", "q96", "q97", "q99")
+  val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7",
+    "q11", "q13", "q15", "q16", "q18", "q19",
+    "q22", "q25", "q26", "q28", "q29",
+    "q30", "q31", "q32", "q37",
+    "q41", "q42", "q43", "q46", "q48",
+    "q50", "q52", "q55", "q58", "q59",
+    "q61", "q62", "q65", "q68", "q69",
+    "q79",
+    "q81", "q82", "q85", "q88",
+    "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q99")
 
   tpcdsQueries.foreach {
     q =>
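
For scale: the successfulSQL set grows from 21 to 48 of the 99 TPC-DS queries, which is the improvement the commit message refers to.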
