Skip to content

Commit

Permalink
feat(spark): add some numeric function mappings (#317)
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Coleman <[email protected]>
  • Loading branch information
andrew-coleman authored Nov 26, 2024
1 parent 7a9ac66 commit 6bb46ac
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,26 @@ class FunctionMappings {
s[Subtract]("subtract"),
s[Multiply]("multiply"),
s[Divide]("divide"),
s[Abs]("abs"),
s[Remainder]("modulus"),
s[Pow]("power"),
s[Exp]("exp"),
s[Sqrt]("sqrt"),
s[Sin]("sin"),
s[Cos]("cos"),
s[Tan]("tan"),
s[Asin]("asin"),
s[Acos]("acos"),
s[Atan]("atan"),
s[Atan2]("atan2"),
s[Sinh]("sinh"),
s[Cosh]("cosh"),
s[Tanh]("tanh"),
s[Asinh]("asinh"),
s[Acosh]("acosh"),
s[Atanh]("atanh"),
s[Log]("ln"),
s[Log10]("log10"),
s[And]("and"),
s[Or]("or"),
s[Not]("not"),
Expand Down Expand Up @@ -77,7 +97,8 @@ class FunctionMappings {
s[Min]("min"),
s[Max]("max"),
s[First]("any_value"),
s[HyperLogLogPlusPlus]("approx_count_distinct")
s[HyperLogLogPlusPlus]("approx_count_distinct"),
s[StddevSamp]("std_dev")
)

val WINDOW_SIGS: Seq[Sig] = Seq(
Expand Down
40 changes: 40 additions & 0 deletions spark/src/test/scala/io/substrait/spark/NumericSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package io.substrait.spark

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSparkSession

/**
 * Exercises numeric SQL functions (arithmetic helpers, exponentials/logarithms and
 * trigonometric/hyperbolic functions) by handing each query to
 * `assertSqlSubstraitRelRoundTrip`, which is provided by [[SubstraitPlanTestBase]].
 */
class NumericSuite extends SparkFunSuite with SharedSparkSession with SubstraitPlanTestBase {

  /** Runs the inherited setup first, then sets the Spark context log level to WARN. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    sparkContext.setLogLevel("WARN")
  }

  test("basic") {
    val query =
      "select sqrt(abs(num)), mod(num, 2) from (values (-5), (7.4)) as table(num)"
    assertSqlSubstraitRelRoundTrip(query)
  }

  test("exponentials") {
    val query =
      "select power(num, 3), exp(num), ln(num), log10(num) from (values (5), (17)) as table(num)"
    assertSqlSubstraitRelRoundTrip(query)
  }

  test("trig") {
    // Checked in this order: circular, inverse circular, hyperbolic, inverse hyperbolic.
    val queries = Seq(
      "select sin(num), cos(num), tan(num) from (values (30), (90)) as table(num)",
      "select asin(num), acos(num), atan(num) from (values (0.5), (-0.5)) as table(num)",
      "select sinh(num), cosh(num), tanh(num) from (values (30), (90)) as table(num)",
      "select asinh(num), acosh(num), atanh(num) from (values (0.5), (-0.5)) as table(num)"
    )
    queries.foreach(query => assertSqlSubstraitRelRoundTrip(query))
  }

}
19 changes: 8 additions & 11 deletions spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,18 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
}

// spotless:off
val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q5", "q7", "q8",
"q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q18", "q19",
"q20", "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29",
"q30", "q31", "q32", "q33", "q36", "q37", "q38",
"q40", "q41", "q42", "q43", "q44", "q46", "q48", "q49",
"q50", "q52", "q54", "q55", "q56", "q58", "q59",
"q60", "q61", "q62", "q65", "q66", "q67", "q68", "q69",
"q70", "q71", "q73", "q76", "q77", "q79",
"q80", "q81", "q82", "q85", "q86", "q87", "q88",
"q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99")
// Names of TPC-DS queries that are excluded from the "check simplified" tests below
// unless runAllQueriesIncludeFailed is set; every query not listed here is tested.
// Each entry (or group of entries) is annotated with the reason it is excluded.
val failingSQL: Set[String] = Set(
"q2", // because round() isn't defined in substrait to work with Decimal. https://github.com/substrait-io/substrait/pull/713
"q9", // requires implementation of named_struct()
"q10", "q35", "q45", // Unsupported join type ExistenceJoin (this is an internal spark type)
"q51", "q83", "q84", // TBD
"q72" // requires implementation of date_add()
)
// spotless:on

tpcdsQueries.foreach {
q =>
if (runAllQueriesIncludeFailed || successfulSQL.contains(q)) {
if (runAllQueriesIncludeFailed || !failingSQL.contains(q)) {
test(s"check simplified (tpcds-v1.4/$q)") {
testQuery("tpcds", q)
}
Expand Down

0 comments on commit 6bb46ac

Please sign in to comment.