Fix collection_ops_tests for Spark 4.0 [databricks] #11414

Open · wants to merge 6 commits into base: branch-24.12
Changes from 2 commits
11 changes: 8 additions & 3 deletions integration_tests/src/main/python/collection_ops_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, NVIDIA CORPORATION.
+# Copyright (c) 2021-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from data_gen import *
from pyspark.sql.types import *

+from src.main.python.spark_session import is_before_spark_400
Collaborator:

nit: To be consistent with other files, this should just be

Suggested change
-from src.main.python.spark_session import is_before_spark_400
+from spark_session import is_before_spark_400

from string_test import mk_str_gen
import pyspark.sql.functions as f
import pyspark.sql.utils
@@ -326,8 +328,11 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
@pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_sequence_too_long_sequence(stop_gen):
msg = "Too long sequence" if is_before_spark_334() or (not is_before_spark_340() and is_before_spark_342()) \
or is_spark_350() else "Unsuccessful try to create array with"
msg = "Too long sequence" if is_before_spark_334() \
or (not is_before_spark_340() and is_before_spark_342()) \
or is_spark_350() \
else "Can't create array" if not is_before_spark_400() \
else "Unsuccessful try to create array with"
assert_gpu_and_cpu_error(
# To avoid OOM, reduce the row number to 1, it is enough to verify this case.
lambda spark:unary_op_df(spark, stop_gen, 1).selectExpr(
@@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
import java.util.Optional

import ai.rapids.cudf
-import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar, SegmentedReductionAggregation, Table}
+import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, ReductionAggregation, Scalar, SegmentedReductionAggregation, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
@@ -1535,7 +1535,8 @@ object GpuSequenceUtil {
def computeSequenceSize(
start: ColumnVector,
stop: ColumnVector,
-step: ColumnVector): ColumnVector = {
+step: ColumnVector,
+functionName: String): ColumnVector = {
checkSequenceInputs(start, stop, step)
val actualSize = GetSequenceSize(start, stop, step)
val sizeAsLong = withResource(actualSize) { _ =>
@@ -1557,7 +1558,11 @@
// check max size
withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
-require(isAllValidTrue(allValid), GetSequenceSize.TOO_LONG_SEQUENCE)
+withResource(sizeAsLong.reduce(ReductionAggregation.max())) { maxSizeScalar =>
+require(isAllValidTrue(allValid),
+GetSequenceSize.TOO_LONG_SEQUENCE(maxSizeScalar.getLong.asInstanceOf[Int],
+functionName))
+}
}
}
// cast to int and return
@@ -1597,7 +1602,7 @@ case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expr
val steps = stepGpuColOpt.map(_.getBase.incRefCount())
.getOrElse(defaultStepsFunc(startCol, stopCol))
closeOnExcept(steps) { _ =>
-(computeSequenceSize(startCol, stopCol, steps), steps)
+(computeSequenceSize(startCol, stopCol, steps, prettyName), steps)
}
}

@@ -42,7 +42,12 @@ import com.nvidia.spark.rapids.Arm._
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object GetSequenceSize {
val TOO_LONG_SEQUENCE = s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
def TOO_LONG_SEQUENCE(sequenceLength: Int, functionName: String) = {
Collaborator:

function name should be camelCase

+// For these Spark versions, the sequence length and function name
+// do not appear in the exception message.
+s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
+}

/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) contains at least one null element will produce
@@ -28,11 +28,12 @@ import ai.rapids.cudf._
import com.nvidia.spark.rapids.Arm._

import org.apache.spark.sql.rapids.{AddOverflowChecks, SubtractOverflowChecks}
+import org.apache.spark.sql.rapids.shims.SequenceSizeError
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object GetSequenceSize {
val TOO_LONG_SEQUENCE = "Unsuccessful try to create array with elements exceeding the array " +
s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
def TOO_LONG_SEQUENCE(sequenceLength: Int, functionName: String): String =
SequenceSizeError.getTooLongSequenceErrorString(sequenceLength, functionName)
/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) contains at least one null element will produce
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "334"}
{"spark": "342"}
{"spark": "343"}
{"spark": "351"}
{"spark": "352"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object SequenceSizeError {
  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
    // The errant function's name does not feature in the exception message
    // prior to Spark 4.0. Neither does the attempted allocation size.
    "Unsuccessful try to create array with elements exceeding the array " +
      s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
  }
}
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.errors.QueryExecutionErrors

object SequenceSizeError {
  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
    QueryExecutionErrors.createArrayWithElementsExceedLimitError(functionName, sequenceSize)
      .getMessage
  }
}
Collaborator:

We should move this to RapidsErrorUtils.

The way I would do it is to remove TOO_LONG_SEQUENCE from GetSequenceSize.scala altogether.

Then introduce another trait, RapidsErrorUtilsForSequence, with a method tooLongSequenceError. The 320 versions will have their own implementation returning the hardcoded "Too long..." message, 334-352 will have their own implementation returning "Unsuccessful try...", and 400 will have its own that returns QueryExecutionErrors.createArrayWithElementsExceedLimitError.
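
A minimal sketch of that suggestion, not the code in this PR: RapidsErrorUtilsForSequence and tooLongSequenceError follow the comment above, the per-version trait names are hypothetical, the message strings are taken from the shims in this PR, and each implementation would live in its own shim directory (the 400 variant only compiles against Spark 4.0).

```scala
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

// Shared trait, mixed into each shim's RapidsErrorUtils.
trait RapidsErrorUtilsForSequence {
  def tooLongSequenceError(sequenceSize: Int, functionName: String): String
}

// 320-style shims: neither the size nor the function name appears in the message.
trait SequenceTooLongErrorShim320 extends RapidsErrorUtilsForSequence {
  override def tooLongSequenceError(sequenceSize: Int, functionName: String): String =
    s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
}

// 334-352-style shims: the "Unsuccessful try" wording.
trait SequenceTooLongErrorShim334Plus extends RapidsErrorUtilsForSequence {
  override def tooLongSequenceError(sequenceSize: Int, functionName: String): String =
    "Unsuccessful try to create array with elements exceeding the array " +
      s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
}

// 400 shim: delegate to Spark's own error builder (the same call this PR uses).
trait SequenceTooLongErrorShim400 extends RapidsErrorUtilsForSequence {
  override def tooLongSequenceError(sequenceSize: Int, functionName: String): String =
    QueryExecutionErrors.createArrayWithElementsExceedLimitError(functionName, sequenceSize)
      .getMessage
}
```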

Collaborator Author:

I think we're talking at cross purposes. Or maybe I'm simply unable to grok the suggestion yet.

> The 320 versions will have their own implementation returning the hardcoded "Too long..." message, 334-352 will have their own implementation

This largely describes what my patch currently does, save for moving it into RapidsErrorUtils.

The reason (I think) I can't move this into RapidsErrorUtils is that the error messages differ within the same major Spark version.

The error messages are split as follows:

  1. Too long sequence found:
    1. 3.2.*
    2. 3.3.[x<4]
    3. 3.4.[x<2]
    4. 3.5.0
  2. Unsuccessful try to create array:
    1. 3.3.4
    2. 3.4.[2-3]
    3. 3.5.[1-2]
  3. Can't create array with...:
    1. Only 4.0.

The RapidsErrorUtils shims are grouped by version, as follows:

  1. All 3.2.* together.
  2. All 3.3.* except 3.3.*db
  3. All 3.3.[0,2]db
  4. 3.4.[0-3] + 3.5.[0-2] + 4.0.0.

Now if I'm trying to accommodate the correct error message for, say, Spark 3.4.2, I can't because there's only one RapidsErrorUtils for all of Spark 3.4.x (and that class happens to affect all 3.5.x as well as 4.0).

Is the suggestion to further slice up 3.4.*'s RapidsErrorUtils? We would then have to slice up 3.3.x and 3.5.x the same way, with code duplicated everywhere. This doesn't sound productive to me.

Maybe I've missed something. Perhaps we should discuss this offline, and update the result on this bug.
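
For reference, a small hypothetical sketch that pins down the mismatch described above (version groupings taken from the lists; the object name is illustrative only, and the 4.0 entry shows just the prefix the integration test checks for):

```scala
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

// One RapidsErrorUtils shim currently spans 3.4.x, 3.5.x and 4.0.0, yet the
// expected "too long sequence" message changes inside that range.
object ExpectedSequenceMessage {
  def forVersion(sparkVersion: String): String = sparkVersion match {
    case "3.4.0" | "3.4.1" | "3.5.0" =>
      s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
    case "3.4.2" | "3.4.3" | "3.5.1" | "3.5.2" =>
      "Unsuccessful try to create array with elements exceeding the array " +
        s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
    case "4.0.0" =>
      // The full text comes from QueryExecutionErrors.createArrayWithElementsExceedLimitError.
      "Can't create array"
    case other =>
      sys.error(s"Version not covered by this sketch: $other")
  }
}
```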

Collaborator Author:

This refactor is turning into a rats' nest. When the next shim needs to be added, and things need to be split further, I think it's going to be unreadable.

RapidsErrorUtilsBase, used in 33xdb, is a very misleading name for the errors shim. It seems to apply only in 33x, while its name suggests that it's the base class for all the RapidsErrorUtils. This is painful.

I'm going to try to add this with as little collateral damage as I can.
