Fix collection_ops_tests for Spark 4.0 [databricks] #11414

Open
wants to merge 6 commits into base: branch-24.12

Conversation

mythrocks
Collaborator

Fixes #11011.

This commit fixes the failures in collection_ops_tests on Spark 4.0.

On all versions of Spark, when a Sequence with more than MAX_INT rows is collected,
an exception is thrown indicating that the collected Sequence/array is
larger than permissible. The versions of Spark differ in the contents of the
exception message.

On Spark 4, the error message now contains more information than in all prior
versions, including:

  1. The name of the op causing the error.
  2. The errant sequence size.

This commit introduces a shim to make this new information available in
the exception.

Note that this shim does not fit cleanly in RapidsErrorUtils, because
there are differences within major Spark versions. For instance, Spark
3.4.0 and 3.4.1 have a different message from 3.4.2 and 3.4.3.
Likewise, the message differs among 3.5.0, 3.5.1, and 3.5.2.
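
For illustration, the shim amounts to turning the previously fixed error string into a function of those two pieces of information. A minimal sketch, assuming a hypothetical object name (the real change is in the diffs reviewed below):

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

// Hypothetical shape of the shimmed helper. Pre-4.0 builds can ignore both
// arguments and keep returning the old fixed text; the Spark 4.0 build can
// fold the function name and the errant size into the message instead.
object GetSequenceSizeErrorSketch {
  def tooLongSequenceError(sequenceLength: Int, functionName: String): String =
    s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
}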

@mythrocks mythrocks self-assigned this Aug 30, 2024
@mythrocks mythrocks added the Spark 4.0+ label Aug 30, 2024
@mythrocks
Collaborator Author

Build

@@ -17,6 +17,8 @@
 from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
 from data_gen import *
 from pyspark.sql.types import *
+
+from src.main.python.spark_session import is_before_spark_400
Collaborator

nit: To be consistent with other files, this should just be

Suggested change
from src.main.python.spark_session import is_before_spark_400
from spark_session import is_before_spark_400

@@ -42,7 +42,12 @@ import com.nvidia.spark.rapids.Arm._
 import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

 object GetSequenceSize {
-  val TOO_LONG_SEQUENCE = s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
+  def TOO_LONG_SEQUENCE(sequenceLength: Int, functionName: String) = {
Collaborator

function name should be camelCase


object SequenceSizeError {
  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
    // Surface the message from Spark's error builder, which includes the function name and size.
    QueryExecutionErrors.createArrayWithElementsExceedLimitError(functionName, sequenceSize).getMessage
  }
}
Collaborator

We should move this to RapidsErrorUtils.

The way I would do it is to remove TOO_LONG_SEQUENCE from GetSequenceSize.scala altogether.

Then introduce another trait, RapidsErrorUtilsForSequence, with a method tooLongSequenceError. The 320 versions would have their own implementation returning the hardcoded "Too long..." message, 334-352 would have their own implementation returning "Unsuccessful try...", and 400 would have its own that returns QueryExecutionErrors.createArrayWithElementsExceedLimitError.
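
A rough sketch of that proposal might look like the following. Everything here is illustrative: the trait and shim names, the version groupings, and the use of .getMessage are assumptions drawn from this discussion, not the code in this PR.

import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

// Hypothetical trait, per the suggestion above: each shim supplies its own message.
trait RapidsErrorUtilsForSequence {
  def tooLongSequenceError(sequenceSize: Int, functionName: String): String
}

// 320-style shims: fixed text that ignores both arguments.
trait RapidsErrorUtilsForSequence320 extends RapidsErrorUtilsForSequence {
  override def tooLongSequenceError(sequenceSize: Int, functionName: String): String =
    s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
}

// 400-style shim: delegate to Spark's error builder, whose message includes
// the function name and the offending sequence size.
trait RapidsErrorUtilsForSequence400 extends RapidsErrorUtilsForSequence {
  override def tooLongSequenceError(sequenceSize: Int, functionName: String): String =
    QueryExecutionErrors.createArrayWithElementsExceedLimitError(functionName, sequenceSize)
      .getMessage
}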

Collaborator Author

I think we're talking at cross purposes. Or maybe I'm simply unable to grok the suggestion yet.

> The 320 versions would have their own implementation returning the hardcoded "Too long..." message, 334-352 would have their own implementation

This largely describes what my patch currently does, save for moving it into RapidsErrorUtils.

The reason (I think) I can't move this into RapidsErrorUtils is that the error messages differ within the same major Spark version.

The error messages are split as follows:

  1. Too long sequence found:
    1. 3.2.*
    2. 3.3.[x<4]
    3. 3.4.[x<2]
    4. 3.5.0
  2. Unsuccessful try to create array:
    1. 3.3.4
    2. 3.4.[2-3]
    3. 3.5.[1-2]
  3. Can't create array with...:
    1. Only 4.0.

The RapidsErrorUtils shims are grouped by version, as follows:

  1. All 3.2.* together.
  2. All 3.3.* except 3.3.*db
  3. All 3.3.[0,2]db
  4. 3.4.[0-3] + 3.5.[0-2] + 4.0.0.

Now if I'm trying to accommodate the correct error message for, say, Spark 3.4.2, I can't because there's only one RapidsErrorUtils for all of Spark 3.4.x (and that class happens to affect all 3.5.x as well as 4.0).

Is the suggestion to slice up 3.4.*'s RapidsErrorUtils further? We would then have to slice up 3.3.x and 3.5.x the same way as well, with code duplicated everywhere. This doesn't sound productive to me.

Maybe I've missed something. Perhaps we should discuss this offline and record the outcome on this bug.

Collaborator Author

This refactor is turning into a rat's nest. When the next shim needs to be added and things need to be split further, I think it's going to be unreadable.

RapidsErrorUtilsBase, used in 33xdb, is a very misleading name for the errors shim. It seems to apply only to 33x, while its name suggests that it's the base class for all the RapidsErrorUtils shims. This is painful.

I'm going to try to add this with as little collateral damage as I can.

This moves the construction of the long-sequence error strings into
RapidsErrorUtils.  The process involved introducing many new RapidsErrorUtils
classes, and using mix-ins of concrete implementations for the error-string
construction.
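
As a rough illustration of the mix-in arrangement described in that commit message (the trait name and the single message text shown here are assumptions, not the actual shim layout):

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

// Hypothetical mix-in: one concrete implementation of the error-string
// construction, shared by every Spark version that produces this message text.
trait SequenceSizeTooLongErrorBuilder {
  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String =
    s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
}

// Each per-version RapidsErrorUtils object then mixes in whichever builder
// matches its Spark release, instead of duplicating the string construction.
object RapidsErrorUtils extends SequenceSizeTooLongErrorBuilder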
@mythrocks mythrocks changed the base branch from branch-24.10 to branch-24.12 September 28, 2024 02:02
@mythrocks
Collaborator Author

Apologies for the noise. I had to rebase this to target branch-24.12, which then caused a lot of new reviewers to be added.

@razajafri is already examining this change. The others can ignore this.

@mythrocks
Collaborator Author

Build

@mythrocks mythrocks changed the title from "Fix collection_ops_tests for Spark 4.0" to "Fix collection_ops_tests for Spark 4.0 [databricks]" Sep 28, 2024
@mythrocks
Collaborator Author

Build

Labels
Spark 4.0+ (Spark 4.0+ issues)
Development

Successfully merging this pull request may close these issues.

Fix tests failures in collection_ops_test.py
2 participants