Fix collection_ops_tests for Spark 4.0 [databricks] #11414

Merged 10 commits on Oct 12, 2024
integration_tests/src/main/python/collection_ops_test.py (8 additions, 3 deletions)
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from data_gen import *
from pyspark.sql.types import *

from spark_session import is_before_spark_400
from string_test import mk_str_gen
import pyspark.sql.functions as f
import pyspark.sql.utils
@@ -326,8 +328,11 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
@pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_sequence_too_long_sequence(stop_gen):
msg = "Too long sequence" if is_before_spark_334() or (not is_before_spark_340() and is_before_spark_342()) \
or is_spark_350() else "Unsuccessful try to create array with"
msg = "Too long sequence" if is_before_spark_334() \
or (not is_before_spark_340() and is_before_spark_342()) \
or is_spark_350() \
else "Can't create array" if not is_before_spark_400() \
else "Unsuccessful try to create array with"
assert_gpu_and_cpu_error(
# To avoid OOM, reduce the row number to 1; it is enough to verify this case.
lambda spark:unary_op_df(spark, stop_gen, 1).selectExpr(
@@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
import java.util.Optional

import ai.rapids.cudf
import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar, SegmentedReductionAggregation, Table}
import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, ReductionAggregation, Scalar, SegmentedReductionAggregation, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
@@ -1651,7 +1651,8 @@ object GpuSequenceUtil {
def computeSequenceSize(
start: ColumnVector,
stop: ColumnVector,
step: ColumnVector): ColumnVector = {
step: ColumnVector,
functionName: String): ColumnVector = {
checkSequenceInputs(start, stop, step)
val actualSize = GetSequenceSize(start, stop, step)
val sizeAsLong = withResource(actualSize) { _ =>
@@ -1673,7 +1674,12 @@
// check max size
withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
require(isAllValidTrue(allValid), GetSequenceSize.TOO_LONG_SEQUENCE)
withResource(sizeAsLong.reduce(ReductionAggregation.max())) { maxSizeScalar =>
require(isAllValidTrue(allValid),
RapidsErrorUtils.getTooLongSequenceErrorString(
maxSizeScalar.getLong.asInstanceOf[Int],
functionName))
}
}
}
// cast to int and return
@@ -1713,7 +1719,7 @@ case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expr
val steps = stepGpuColOpt.map(_.getBase.incRefCount())
.getOrElse(defaultStepsFunc(startCol, stopCol))
closeOnExcept(steps) { _ =>
(computeSequenceSize(startCol, stopCol, steps), steps)
(computeSequenceSize(startCol, stopCol, steps, prettyName), steps)
}
}

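Read in isolation, the change to computeSequenceSize has two parts: the overflow check now also reduces the computed sizes to their maximum (via ReductionAggregation.max()), and the failure text is delegated to RapidsErrorUtils.getTooLongSequenceErrorString, which receives that maximum together with the expression's prettyName. A minimal sketch of the same pattern, with illustrative names and plain Scala collections standing in for the cudf columns the plugin actually uses:

// Illustrative sketch only; not the plugin's real classes.
object SequenceSizeCheckSketch {
  // Spark's ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
  val MaxRoundedArrayLength: Long = Int.MaxValue - 15L

  // Per-shim message builder; the pre-4.0 builders ignore both arguments.
  type SizeErrorBuilder = (Long, String) => String

  val legacyBuilder: SizeErrorBuilder = (_, _) =>
    s"Too long sequence found. Should be <= $MaxRoundedArrayLength"

  def checkSequenceSizes(sizes: Seq[Long], functionName: String,
      buildError: SizeErrorBuilder): Unit = {
    // Reduce to the largest attempted size, mirroring ReductionAggregation.max() on the GPU.
    val maxSize = sizes.max
    require(maxSize <= MaxRoundedArrayLength, buildError(maxSize, functionName))
  }

  def main(args: Array[String]): Unit = {
    checkSequenceSizes(Seq(10L, 250L), "sequence", legacyBuilder)  // passes
    // A size above the limit would throw IllegalArgumentException with the built message.
  }
}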
@@ -39,10 +39,7 @@ package com.nvidia.spark.rapids.shims
import ai.rapids.cudf._
import com.nvidia.spark.rapids.Arm._

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object GetSequenceSize {
val TOO_LONG_SEQUENCE = s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) that contains at least one null element will produce
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "320"}
{"spark": "321"}
{"spark": "321cdh"}
{"spark": "322"}
{"spark": "323"}
{"spark": "324"}
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "350"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

trait SequenceSizeTooLongErrorBuilder {

def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
// For these Spark versions, the sequence length and function name
// do not appear in the exception message.
s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
}
}
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils extends RapidsQueryErrorUtils {
object RapidsErrorUtils extends RapidsQueryErrorUtils with SequenceSizeTooLongErrorBuilder {
def invalidArrayIndexError(index: Int, numElements: Int,
isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
// Follow the Spark string format before 3.3.0
@@ -21,64 +21,9 @@
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "333"}
{"spark": "334"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.SparkDateTimeException
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}
object RapidsErrorUtils extends RapidsErrorUtils330To334Base
with SequenceSizeTooLongErrorBuilder

object RapidsErrorUtils extends RapidsErrorUtilsFor330plus with RapidsQueryErrorUtils {

def mapKeyNotExistError(
key: String,
keyType: DataType,
origin: Origin): NoSuchElementException = {
QueryExecutionErrors.mapKeyNotExistError(key, keyType, origin.context)
}

def invalidArrayIndexError(index: Int, numElements: Int,
isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
if (isElementAtF) {
QueryExecutionErrors.invalidElementAtIndexError(index, numElements)
} else {
QueryExecutionErrors.invalidArrayIndexError(index, numElements)
}
}

def arithmeticOverflowError(
message: String,
hint: String = "",
errorContext: String = ""): ArithmeticException = {
QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext)
}

def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
QueryExecutionErrors.cannotChangeDecimalPrecisionError(
value, toType.precision, toType.scale, context
)
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
QueryExecutionErrors.arithmeticOverflowError(
"Overflow in integral divide", "try_divide", context
)
}

def sparkDateTimeException(infOrNan: String): SparkDateTimeException = {
// These are the arguments required by SparkDateTimeException class to create error message.
val errorClass = "CAST_INVALID_INPUT"
val messageParameters = Array("DOUBLE", "TIMESTAMP", SQLConf.ANSI_ENABLED.key)
new SparkDateTimeException(errorClass, Array(infOrNan) ++ messageParameters)
}

def sqlArrayIndexNotStartAtOneError(): RuntimeException = {
new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
}
}
@@ -0,0 +1,84 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "333"}
{"spark": "334"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.SparkDateTimeException
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

trait RapidsErrorUtils330To334Base extends RapidsErrorUtilsFor330plus with RapidsQueryErrorUtils {

def mapKeyNotExistError(
key: String,
keyType: DataType,
origin: Origin): NoSuchElementException = {
QueryExecutionErrors.mapKeyNotExistError(key, keyType, origin.context)
}

def invalidArrayIndexError(index: Int, numElements: Int,
isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
if (isElementAtF) {
QueryExecutionErrors.invalidElementAtIndexError(index, numElements)
} else {
QueryExecutionErrors.invalidArrayIndexError(index, numElements)
}
}

def arithmeticOverflowError(
message: String,
hint: String = "",
errorContext: String = ""): ArithmeticException = {
QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext)
}

def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
QueryExecutionErrors.cannotChangeDecimalPrecisionError(
value, toType.precision, toType.scale, context
)
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
QueryExecutionErrors.arithmeticOverflowError(
"Overflow in integral divide", "try_divide", context
)
}

def sparkDateTimeException(infOrNan: String): SparkDateTimeException = {
// These are the arguments required by SparkDateTimeException class to create error message.
val errorClass = "CAST_INVALID_INPUT"
val messageParameters = Array("DOUBLE", "TIMESTAMP", SQLConf.ANSI_ENABLED.key)
new SparkDateTimeException(errorClass, Array(infOrNan) ++ messageParameters)
}

def sqlArrayIndexNotStartAtOneError(): RuntimeException = {
new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
}
}
@@ -22,7 +22,8 @@ package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.errors.QueryExecutionErrors

object RapidsErrorUtils extends RapidsErrorUtilsBase with RapidsQueryErrorUtils {
object RapidsErrorUtils extends RapidsErrorUtilsBase
with RapidsQueryErrorUtils with SequenceSizeTooLongErrorBuilder {
def sqlArrayIndexNotStartAtOneError(): RuntimeException = {
QueryExecutionErrors.elementAtByIndexZeroError(context = null)
}
@@ -31,8 +31,6 @@ import org.apache.spark.sql.rapids.{AddOverflowChecks, SubtractOverflowChecks}
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object GetSequenceSize {
val TOO_LONG_SEQUENCE = "Unsuccessful try to create array with elements exceeding the array " +
s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) that contains at least one null element will produce
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "334"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

object RapidsErrorUtils extends RapidsErrorUtils330To334Base
with SequenceSizeTooLongUnsuccessfulErrorBuilder

@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "334"}
{"spark": "342"}
{"spark": "343"}
{"spark": "351"}
{"spark": "352"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

trait SequenceSizeTooLongUnsuccessfulErrorBuilder {
def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
// The errant function's name does not feature in the exception message
// prior to Spark 4.0. Neither does the attempted allocation size.
"Unsuccessful try to create array with elements exceeding the array " +
s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
}
}
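Note that neither builder in this diff uses its sequenceSize or functionName argument; the parameters exist so that newer shims can. The Spark 4.0 builder itself is not shown in this excerpt, and the only hint at its wording is the "Can't create array" prefix the updated test matches. Purely as a hypothetical sketch, with an assumed name and message format rather than the real 4.0 shim:

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

// Hypothetical: shows how a 4.0-style builder could consume both arguments.
trait SequenceSizeErrorBuilderFor400Sketch {
  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String =
    s"Can't create array in $functionName: requested $sequenceSize elements, " +
      s"exceeding the array size limit $MAX_ROUNDED_ARRAY_LENGTH"
}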