From 5d9d73a0e0e3e3615b556c841c6678b7334909df Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Thu, 7 Nov 2024 13:57:43 +0000 Subject: [PATCH] feat(spark): add some numeric function mappings Signed-off-by: Andrew Coleman --- .../spark/expression/FunctionMappings.scala | 23 ++++++++++- .../io/substrait/spark/NumericSuite.scala | 40 +++++++++++++++++++ .../scala/io/substrait/spark/TPCDSPlan.scala | 19 ++++----- 3 files changed, 70 insertions(+), 12 deletions(-) create mode 100644 spark/src/test/scala/io/substrait/spark/NumericSuite.scala diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala index d37274822..fa780f6b5 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala @@ -39,6 +39,26 @@ class FunctionMappings { s[Subtract]("subtract"), s[Multiply]("multiply"), s[Divide]("divide"), + s[Abs]("abs"), + s[Remainder]("modulus"), + s[Pow]("power"), + s[Exp]("exp"), + s[Sqrt]("sqrt"), + s[Sin]("sin"), + s[Cos]("cos"), + s[Tan]("tan"), + s[Asin]("asin"), + s[Acos]("acos"), + s[Atan]("atan"), + s[Atan2]("atan2"), + s[Sinh]("sinh"), + s[Cosh]("cosh"), + s[Tanh]("tanh"), + s[Asinh]("asinh"), + s[Acosh]("acosh"), + s[Atanh]("atanh"), + s[Log]("ln"), + s[Log10]("log10"), s[And]("and"), s[Or]("or"), s[Not]("not"), @@ -77,7 +97,8 @@ class FunctionMappings { s[Min]("min"), s[Max]("max"), s[First]("any_value"), - s[HyperLogLogPlusPlus]("approx_count_distinct") + s[HyperLogLogPlusPlus]("approx_count_distinct"), + s[StddevSamp]("std_dev") ) val WINDOW_SIGS: Seq[Sig] = Seq( diff --git a/spark/src/test/scala/io/substrait/spark/NumericSuite.scala b/spark/src/test/scala/io/substrait/spark/NumericSuite.scala new file mode 100644 index 000000000..1ca7bc931 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/NumericSuite.scala @@ -0,0 +1,40 @@ +package 
io.substrait.spark + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSparkSession + +class NumericSuite extends SparkFunSuite with SharedSparkSession with SubstraitPlanTestBase { + + override def beforeAll(): Unit = { + super.beforeAll() + sparkContext.setLogLevel("WARN") + } + + test("basic") { + assertSqlSubstraitRelRoundTrip( + "select sqrt(abs(num)), mod(num, 2) from (values (-5), (7.4)) as table(num)" + ) + } + + test("exponentials") { + assertSqlSubstraitRelRoundTrip( + "select power(num, 3), exp(num), ln(num), log10(num) from (values (5), (17)) as table(num)" + ) + } + + test("trig") { + assertSqlSubstraitRelRoundTrip( + "select sin(num), cos(num), tan(num) from (values (30), (90)) as table(num)" + ) + assertSqlSubstraitRelRoundTrip( + "select asin(num), acos(num), atan(num) from (values (0.5), (-0.5)) as table(num)" + ) + assertSqlSubstraitRelRoundTrip( + "select sinh(num), cosh(num), tanh(num) from (values (30), (90)) as table(num)" + ) + assertSqlSubstraitRelRoundTrip( + "select asinh(num), acosh(num), atanh(num) from (values (0.5), (-0.5)) as table(num)" + ) + } + +} diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index 5d3ff29aa..eb5403549 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -32,21 +32,18 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { } // spotless:off - val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q5", "q7", "q8", - "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q18", "q19", - "q20", "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29", - "q30", "q31", "q32", "q33", "q36", "q37", "q38", - "q40", "q41", "q42", "q43", "q44", "q46", "q48", "q49", - "q50", "q52", "q54", "q55", "q56", "q58", "q59", - "q60", "q61", "q62", "q65", "q66", "q67", "q68", "q69", - "q70", "q71", "q73", "q76", "q77", 
"q79", - "q80", "q81", "q82", "q85", "q86", "q87", "q88", - "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + val failingSQL: Set[String] = Set( + "q2", // because round() isn't defined in Substrait to work with Decimal. https://github.com/substrait-io/substrait/pull/713 + "q9", // requires implementation of named_struct() + "q10", "q35", "q45", // Unsupported join type ExistenceJoin (this is an internal spark type) + "q51", "q83", "q84", // TBD + "q72" // requires implementation of date_add() + ) // spotless:on tpcdsQueries.foreach { q => - if (runAllQueriesIncludeFailed || successfulSQL.contains(q)) { + if (runAllQueriesIncludeFailed || !failingSQL.contains(q)) { test(s"check simplified (tpcds-v1.4/$q)") { testQuery("tpcds", q) }