Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi string contains [databricks] #11413

Open
wants to merge 4 commits into
base: branch-24.12
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions integration_tests/src/main/python/conditionals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,30 @@ def test_case_when_all_then_values_are_scalars_with_nulls():
"tab",
sql_without_else,
conf = {'spark.rapids.sql.case_when.fuse': 'true'})

@pytest.mark.parametrize('combine_string_contains_enabled', ['true', 'false'])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: prefer Python constants

Suggested change
@pytest.mark.parametrize('combine_string_contains_enabled', ['true', 'false'])
@pytest.mark.parametrize('combine_string_contains_enabled', [True, False])

However, the pytest case id will be more readable if, instead of a boolean, parameters are strings

@pytest.mark.parametrize('string_contains_mode', ['multiContains', 'singleContains'], ids=idfn)

def test_combine_string_contains_in_case_when(combine_string_contains_enabled):
data_gen = [("c1", string_gen)]
sql = """
SELECT
CASE
WHEN INSTR(c1, 'a') > 0 THEN 'a'
WHEN INSTR(c1, 'b') > 0 THEN 'b'
WHEN INSTR(c1, 'c') > 0 THEN 'c'
ELSE ''
END as output_1,
CASE
WHEN INSTR(c1, 'c') > 0 THEN 'c'
WHEN INSTR(c1, 'd') > 0 THEN 'd'
WHEN INSTR(c1, 'e') > 0 THEN 'e'
ELSE ''
END as output_2
from tab
"""
# spark.rapids.sql.combined.expressions.enabled is true by default
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, data_gen),
"tab",
sql,
{ "spark.rapids.sql.expression.combined.GpuContains" : combine_string_contains_enabled}
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import java.nio.charset.Charset
import java.text.DecimalFormatSymbols
import java.util.{Locale, Optional}

import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexProgram, RoundMode, Scalar}
Expand All @@ -32,6 +34,7 @@ import com.nvidia.spark.rapids.jni.RegexRewriteUtils
import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.rapids.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -391,7 +394,8 @@ case class GpuContains(left: Expression, right: Expression)
extends GpuBinaryExpressionArgsAnyScalar
with Predicate
with ImplicitCastInputTypes
with NullIntolerant {
with NullIntolerant
with GpuCombinable {

override def inputTypes: Seq[DataType] = Seq(StringType)

Expand All @@ -411,6 +415,105 @@ case class GpuContains(left: Expression, right: Expression)
doColumnar(expandedLhs, rhs)
}
}

/**
* Get a combiner that can be used to find candidates to combine
*/
override def getCombiner(): GpuExpressionCombiner = new ContainsCombiner(this)
}

case class GpuMultiContains(left: Expression, targets: Seq[UTF8String], output: StructType)
extends GpuExpression with ShimExpression {

override def otherCopyArgs: Seq[AnyRef] = Nil

override def dataType: DataType = output

override def nullable: Boolean = false

override def prettyName: String = "multi_contains"

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
val targetsBytes = targets.map(t => t.getBytes).toArray
withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv =>
withResource(left.columnarEval(batch)) { lhs =>
withResource(lhs.getBase.stringContains(targetsCv)) { boolCvs =>
val retView = ColumnView.makeStructView(batch.numRows(), boolCvs: _*)
GpuColumnVector.from(retView.copyToColumnVector(), dataType)
}
}
}
}

override def children: Seq[Expression] = Seq(left)
}

class ContainsCombiner(private val exp: GpuContains) extends GpuExpressionCombiner {
private var outputLocation = 0
/**
* A mapping between an expression and where in the output struct of
* the MultiGetJsonObject will the output be.
*/
private val toCombine = mutable.HashMap.empty[GpuExpressionEquals, Int]
addExpression(exp)

override def toString: String = s"ContainsCombiner $toCombine"

override def hashCode: Int = {
// We already know that we are Contains, and what we can combine is based
// on the string column being the same.
"Contains".hashCode + (exp.left.semanticHash() * 17)
}

/**
* only combine when targets are literals
*/
override def equals(o: Any): Boolean = o match {
case other: ContainsCombiner => exp.left.semanticEquals(other.exp.left) &&
exp.right.isInstanceOf[GpuLiteral] && other.exp.right.isInstanceOf[GpuLiteral]
Comment on lines +471 to +473
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if you make ContainsCombiner a case class you can use pattern matching instead of manual instanceof checks:

Suggested change
override def equals(o: Any): Boolean = o match {
case other: ContainsCombiner => exp.left.semanticEquals(other.exp.left) &&
exp.right.isInstanceOf[GpuLiteral] && other.exp.right.isInstanceOf[GpuLiteral]
override def equals(o: Any): Boolean = (o, exp) match {
case (ContainsCombiner(GpuContains(combLeft, GpuLiteral(_, _))), GpuContains(expLeft, GpuLiteral(_, _))) =>
expLeft.semanticEquals(combLeft)

case _ => false
}

override def addExpression(e: Expression): Unit = {
val localOutputLocation = outputLocation
outputLocation += 1
val key = GpuExpressionEquals(e)
if (!toCombine.contains(key)) {
toCombine.put(key, localOutputLocation)
}
}

override def useCount: Int = toCombine.size

private def fieldName(id: Int): String =
s"_mc_$id"

@tailrec
private def extractLiteral(exp: Expression): GpuLiteral = exp match {
case l: GpuLiteral => l
case a: Alias => extractLiteral(a.child)
case other => throw new RuntimeException("Unsupported expression in contains combiner, " +
"should be a literal type, actual type is " + other.getClass.getName)
}

private lazy val multiContains: GpuMultiContains = {
val input = toCombine.head._1.e.asInstanceOf[GpuContains].left
val fieldsNPaths = toCombine.toSeq.map {
case (k, id) =>
(id, k.e)
}.sortBy(_._1).map {
case (id, e: GpuContains) =>
val target = extractLiteral(e.right).value.asInstanceOf[UTF8String]
(StructField(fieldName(id), e.dataType, e.nullable), target)
}
val dt = StructType(fieldsNPaths.map(_._1))
GpuMultiContains(input, fieldsNPaths.map(_._2), dt)
}

override def getReplacementExpression(e: Expression): Expression = {
val localId = toCombine(GpuExpressionEquals(e))
GpuGetStructField(multiContains, localId, Some(fieldName(localId)))
}
}

case class GpuSubstring(str: Expression, pos: Expression, len: Expression)
Expand Down
Loading