From 3677ff69a24769518a67ee7e8c8eb1991eb8a8b9 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Mon, 26 Aug 2024 11:16:39 +0800 Subject: [PATCH 1/4] Add combiner for string contains Signed-off-by: Chong Gao --- .../src/main/python/conditionals_test.py | 18 +++ .../spark/sql/rapids/stringFunctions.scala | 109 +++++++++++++++++- 2 files changed, 123 insertions(+), 4 deletions(-) diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py index b95ed53f398..a2fb24a7434 100644 --- a/integration_tests/src/main/python/conditionals_test.py +++ b/integration_tests/src/main/python/conditionals_test.py @@ -379,3 +379,21 @@ def test_case_when_all_then_values_are_scalars_with_nulls(): "tab", sql_without_else, conf = {'spark.rapids.sql.case_when.fuse': 'true'}) + +@pytest.mark.parametrize('combine_string_contains_enabled', ['true', 'false']) +def test_combine_string_contains_in_case_when(combine_string_contains_enabled): + data_gen = [("c1", string_gen)] + sql = """ + SELECT + INSTR(c1, 'substring1') > 0, + INSTR(c1, 'substring2') > 0, + INSTR(c1, 'substring3') > 0 + from tab + """ + # spark.rapids.sql.combined.expressions.enabled is true by default + assert_gpu_and_cpu_are_equal_sql( + lambda spark : gen_df(spark, data_gen), + "tab", + sql, + { "spark.rapids.sql.expression.combined.GpuContains" : combine_string_contains_enabled} + ) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index c8a90dc80ad..3ae66ada010 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -20,6 +20,8 @@ import java.nio.charset.Charset import java.text.DecimalFormatSymbols import java.util.{Locale, Optional} +import scala.annotation.tailrec +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexProgram, RoundMode, Scalar} @@ -32,6 +34,7 @@ import com.nvidia.spark.rapids.jni.RegexRewriteUtils import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.rapids.catalyst.expressions.{GpuCombinable, GpuExpressionCombiner, GpuExpressionEquals} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.UTF8String @@ -388,10 +391,11 @@ case class GpuConcatWs(children: Seq[Expression]) } case class GpuContains(left: Expression, right: Expression) - extends GpuBinaryExpressionArgsAnyScalar - with Predicate - with ImplicitCastInputTypes - with NullIntolerant { + extends GpuBinaryExpressionArgsAnyScalar + with Predicate + with ImplicitCastInputTypes + with NullIntolerant + with GpuCombinable { override def inputTypes: Seq[DataType] = Seq(StringType) @@ -411,6 +415,103 @@ case class GpuContains(left: Expression, right: Expression) doColumnar(expandedLhs, rhs) } } + + /** + * Get a combiner that can be used to find candidates to combine + */ + override def getCombiner(): GpuExpressionCombiner = new ContainsCombiner(this) +} + +case class GpuMultiContains(left: Expression, targets: Seq[UTF8String], output: StructType) + extends GpuExpression with ShimExpression { + + override def otherCopyArgs: Seq[AnyRef] = Nil + + override def dataType: DataType = output + + override def nullable: Boolean = false + + override def prettyName: String = "multi_contains" + + override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { + val targetsBytes = targets.map(t => t.getBytes).toArray + withResource(ColumnVector.fromUTF8Strings(targetsBytes : _*)) { targetsCv => + withResource(left.columnarEval(batch)) { lhs => + withResource(lhs.getBase.stringContains(targetsCv)) { boolCvs => + GpuColumnVector.from(ColumnVector.makeStruct(batch.numRows(), boolCvs: _*), dataType) + } + } + } + } + override def children: Seq[Expression] = Seq(left) +} + +class ContainsCombiner(private val exp: GpuContains) extends GpuExpressionCombiner { + private var outputLocation = 0 + /** + * A mapping between an expression and where in the output struct of + * the MultiGetJsonObject will the output be. + */ + private val toCombine = mutable.HashMap.empty[GpuExpressionEquals, Int] + addExpression(exp) + + override def toString: String = s"ContainsCombiner $toCombine" + + override def hashCode: Int = { + // We already know that we are Contains, and what we can combine is based + // on the string column being the same. + "Contains".hashCode + (exp.left.semanticHash() * 17) + } + + /** + * only combine when targets are literals + */ + override def equals(o: Any): Boolean = o match { + case other: ContainsCombiner => exp.left.semanticEquals(other.exp.left) && + exp.right.isInstanceOf[GpuLiteral] && other.exp.right.isInstanceOf[GpuLiteral] + case _ => false + } + + override def addExpression(e: Expression): Unit = { + val localOutputLocation = outputLocation + outputLocation += 1 + val key = GpuExpressionEquals(e) + if (!toCombine.contains(key)) { + toCombine.put(key, localOutputLocation) + } + } + + override def useCount: Int = toCombine.size + + private def fieldName(id: Int): String = + s"_mc_$id" + + @tailrec + private def extractLiteral(exp: Expression): GpuLiteral = exp match { + case l: GpuLiteral => l + case a: Alias => extractLiteral(a.child) + case other => throw new RuntimeException("Unsupported expression in contains combiner, " + + "should be a literal type, actual type is " + other.getClass.getName) + } + + private lazy val multiContains: GpuMultiContains = { + val input = toCombine.head._1.e.asInstanceOf[GpuContains].left + val fieldsNPaths = toCombine.toSeq.map { + case (k, id) => + (id, k.e) + }.sortBy(_._1).map { + case (id, e: GpuContains) => + val target = extractLiteral(e.right).value.asInstanceOf[UTF8String] + (StructField(fieldName(id), e.dataType, e.nullable), target) + } + val dt = StructType(fieldsNPaths.map(_._1)) + GpuMultiContains(input, fieldsNPaths.map(_._2), dt) + } + + override def getReplacementExpression(e: Expression): Expression = { + val localId = toCombine(GpuExpressionEquals(e)) + GpuGetStructField(multiContains, localId, Some(fieldName(localId))) + } } case class GpuSubstring(str: Expression, pos: Expression, len: Expression) From 00f97d1782896ed48abc6ac1095ad1111bd91bdb Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Mon, 2 Sep 2024 13:28:25 +0800 Subject: [PATCH 2/4] Format code --- .../spark/sql/rapids/stringFunctions.scala | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 3ae66ada010..8ed56facf2e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -34,7 +34,7 @@ import com.nvidia.spark.rapids.jni.RegexRewriteUtils import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.rapids.catalyst.expressions.{GpuCombinable, GpuExpressionCombiner, GpuExpressionEquals} +import org.apache.spark.sql.rapids.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.UTF8String @@ -391,11 +391,11 @@ case class GpuConcatWs(children: Seq[Expression]) } case class GpuContains(left: Expression, right: Expression) - extends GpuBinaryExpressionArgsAnyScalar - with Predicate - with ImplicitCastInputTypes - with NullIntolerant - with GpuCombinable { + extends GpuBinaryExpressionArgsAnyScalar + with Predicate + with ImplicitCastInputTypes + with NullIntolerant + with GpuCombinable { override def inputTypes: Seq[DataType] = Seq(StringType) @@ -423,7 +423,7 @@ case class GpuContains(left: Expression, right: Expression) } case class GpuMultiContains(left: Expression, targets: Seq[UTF8String], output: StructType) - extends GpuExpression with ShimExpression { + extends GpuExpression with ShimExpression { override def otherCopyArgs: Seq[AnyRef] = Nil @@ -435,7 +435,7 @@ case class GpuMultiContains(left: Expression, targets: Seq[UTF8String], output: override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { val targetsBytes = targets.map(t => t.getBytes).toArray - withResource(ColumnVector.fromUTF8Strings(targetsBytes : _*)) { targetsCv => + withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => withResource(left.columnarEval(batch)) { lhs => withResource(lhs.getBase.stringContains(targetsCv)) { boolCvs => GpuColumnVector.from(ColumnVector.makeStruct(batch.numRows(), boolCvs: _*), dataType) @@ -443,6 +443,7 @@ case class GpuMultiContains(left: Expression, targets: Seq[UTF8String], output: } } } + override def children: Seq[Expression] = Seq(left) } @@ -468,7 +469,7 @@ class ContainsCombiner(private val exp: GpuContains) extends GpuExpressionCombin */ override def equals(o: Any): Boolean = o match { case other: ContainsCombiner => exp.left.semanticEquals(other.exp.left) && - exp.right.isInstanceOf[GpuLiteral] && other.exp.right.isInstanceOf[GpuLiteral] + exp.right.isInstanceOf[GpuLiteral] && other.exp.right.isInstanceOf[GpuLiteral] case _ => false } @@ -491,7 +492,7 @@ class ContainsCombiner(private val exp: GpuContains) extends GpuExpressionCombin case l: GpuLiteral => l case a: Alias => extractLiteral(a.child) case other => throw new RuntimeException("Unsupported expression in contains combiner, " + - "should be a literal type, actual type is " + other.getClass.getName) + "should be a literal type, actual type is " + other.getClass.getName) } private lazy val multiContains: GpuMultiContains = { From 5003aac1e15fbdcd2c342206805dece898b15018 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Mon, 2 Sep 2024 14:37:14 +0800 Subject: [PATCH 3/4] Improvement: use makeStructView instead of makeStruct --- .../scala/org/apache/spark/sql/rapids/stringFunctions.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 8ed56facf2e..4155bb56e8f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -438,7 +438,8 @@ case class GpuMultiContains(left: Expression, targets: Seq[UTF8String], output: withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => withResource(left.columnarEval(batch)) { lhs => withResource(lhs.getBase.stringContains(targetsCv)) { boolCvs => - GpuColumnVector.from(ColumnVector.makeStruct(batch.numRows(), boolCvs: _*), dataType) + val retView = ColumnView.makeStructView(batch.numRows(), boolCvs: _*) + GpuColumnVector.from(retView.copyToColumnVector(), dataType) } } } From 9f377f2a630383d9e4f696a41621566795686f66 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Tue, 3 Sep 2024 10:27:47 +0800 Subject: [PATCH 4/4] Update pytest --- .../src/main/python/conditionals_test.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py index a2fb24a7434..8c230b13249 100644 --- a/integration_tests/src/main/python/conditionals_test.py +++ b/integration_tests/src/main/python/conditionals_test.py @@ -385,9 +385,18 @@ def test_combine_string_contains_in_case_when(combine_string_contains_enabled): data_gen = [("c1", string_gen)] sql = """ SELECT - INSTR(c1, 'substring1') > 0, - INSTR(c1, 'substring2') > 0, - INSTR(c1, 'substring3') > 0 + CASE + WHEN INSTR(c1, 'a') > 0 THEN 'a' + WHEN INSTR(c1, 'b') > 0 THEN 'b' + WHEN INSTR(c1, 'c') > 0 THEN 'c' + ELSE '' + END as output_1, + CASE + WHEN INSTR(c1, 'c') > 0 THEN 'c' + WHEN INSTR(c1, 'd') > 0 THEN 'd' + WHEN INSTR(c1, 'e') > 0 THEN 'e' + ELSE '' + END as output_2 from tab """ # spark.rapids.sql.combined.expressions.enabled is true by default