diff --git a/integration_tests/src/main/python/collection_ops_test.py b/integration_tests/src/main/python/collection_ops_test.py
index 099eb28c053..813f1a77c94 100644
--- a/integration_tests/src/main/python/collection_ops_test.py
+++ b/integration_tests/src/main/python/collection_ops_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, NVIDIA CORPORATION.
+# Copyright (c) 2021-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@
 from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
 from data_gen import *
 from pyspark.sql.types import *
+
+from spark_session import is_before_spark_400
 from string_test import mk_str_gen
 import pyspark.sql.functions as f
 import pyspark.sql.utils
@@ -326,8 +328,11 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
 @pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
 @allow_non_gpu(*non_utc_allow)
 def test_sequence_too_long_sequence(stop_gen):
-    msg = "Too long sequence" if is_before_spark_334() or (not is_before_spark_340() and is_before_spark_342()) \
-        or is_spark_350() else "Unsuccessful try to create array with"
+    msg = "Too long sequence" if is_before_spark_334() \
+        or (not is_before_spark_340() and is_before_spark_342()) \
+        or is_spark_350() \
+        else "Can't create array" if not is_before_spark_400() \
+        else "Unsuccessful try to create array with"
     assert_gpu_and_cpu_error(
         # To avoid OOM, reduce the row number to 1, it is enough to verify this case.
         lambda spark:unary_op_df(spark, stop_gen, 1).selectExpr(
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
index b675ef2bfbd..23b823e7117 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
 import java.util.Optional
 
 import ai.rapids.cudf
-import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar, SegmentedReductionAggregation, Table}
+import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, ReductionAggregation, Scalar, SegmentedReductionAggregation, Table}
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.Arm._
 import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
@@ -1651,7 +1651,8 @@ object GpuSequenceUtil {
   def computeSequenceSize(
       start: ColumnVector,
       stop: ColumnVector,
-      step: ColumnVector): ColumnVector = {
+      step: ColumnVector,
+      functionName: String): ColumnVector = {
     checkSequenceInputs(start, stop, step)
     val actualSize = GetSequenceSize(start, stop, step)
     val sizeAsLong = withResource(actualSize) { _ =>
@@ -1673,7 +1674,12 @@ object GpuSequenceUtil {
     // check max size
     withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
       withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
-        require(isAllValidTrue(allValid), GetSequenceSize.TOO_LONG_SEQUENCE)
+        withResource(sizeAsLong.reduce(ReductionAggregation.max())) { maxSizeScalar =>
+          require(isAllValidTrue(allValid),
+            RapidsErrorUtils.getTooLongSequenceErrorString(
+              maxSizeScalar.getLong.asInstanceOf[Int],
+              functionName))
+        }
       }
     }
     // cast to int and return
@@ -1713,7 +1719,7 @@ case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expr
       val steps = stepGpuColOpt.map(_.getBase.incRefCount())
         .getOrElse(defaultStepsFunc(startCol, stopCol))
       closeOnExcept(steps) { _ =>
-        (computeSequenceSize(startCol, stopCol, steps), steps)
+        (computeSequenceSize(startCol, stopCol, steps, prettyName), steps)
       }
     }
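For readers following the control flow: the change above threads the expression's `prettyName` down into the size check and, on failure, reports the largest requested size via the shim hook `RapidsErrorUtils.getTooLongSequenceErrorString`. A minimal, hedged sketch of that flow — cudf columns stood in by a plain `Seq[Long]`, the hook passed in as a function; only the names taken from the diff above are real:

```scala
// Hedged sketch, not plugin code. `buildMessage` stands in for the shim hook
// RapidsErrorUtils.getTooLongSequenceErrorString introduced in this PR.
object SequenceSizeCheckSketch {
  // Same limit Spark uses: ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.
  val MaxRoundedArrayLength: Int = Int.MaxValue - 15

  def checkSequenceSizes(
      sizes: Seq[Long],                        // one entry per row, like sizeAsLong
      functionName: String,                    // GpuSequence passes prettyName
      buildMessage: (Int, String) => String): Unit = {
    // Mirrors the new ReductionAggregation.max(): when any row exceeds the
    // limit, the largest requested size is what gets reported.
    // Assumes a non-empty input, as the real code operates on a column batch.
    val maxSize = sizes.max
    require(sizes.forall(_ <= MaxRoundedArrayLength),
      buildMessage(maxSize.toInt, functionName))
  }
}
```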
diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
index 32ca03974bf..deb305cc89c 100644
--- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
+++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
@@ -39,10 +39,7 @@ package com.nvidia.spark.rapids.shims
 import ai.rapids.cudf._
 import com.nvidia.spark.rapids.Arm._
 
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
-
 object GetSequenceSize {
-  val TOO_LONG_SEQUENCE = s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
   /**
    * Compute the size of each sequence according to 'start', 'stop' and 'step'.
    * A row (Row[start, stop, step]) contains at least one null element will produce
diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/SequenceSizeTooLongErrorBuilder.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/SequenceSizeTooLongErrorBuilder.scala
new file mode 100644
index 00000000000..32d38540cb5
--- /dev/null
+++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/SequenceSizeTooLongErrorBuilder.scala
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "320"}
+{"spark": "321"}
+{"spark": "321cdh"}
+{"spark": "322"}
+{"spark": "323"}
+{"spark": "324"}
+{"spark": "330"}
+{"spark": "330cdh"}
+{"spark": "330db"}
+{"spark": "331"}
+{"spark": "332"}
+{"spark": "332cdh"}
+{"spark": "332db"}
+{"spark": "333"}
+{"spark": "340"}
+{"spark": "341"}
+{"spark": "341db"}
+{"spark": "350"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.shims
+
+import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+trait SequenceSizeTooLongErrorBuilder {
+
+  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
+    // For these Spark versions, the sequence length and function name
+    // do not appear in the exception message.
+    s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
+  }
+}
\ No newline at end of file
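Since this builder accepts both arguments but uses neither, a quick check of what it renders may make the version matrix easier to follow. The demo object below is hypothetical; `MAX_ROUNDED_ARRAY_LENGTH` is `Int.MaxValue - 15`, i.e. 2147483632:

```scala
// Hypothetical demo object; not part of the PR.
object Pre334MessageDemo extends SequenceSizeTooLongErrorBuilder

// Arguments are accepted for API uniformity but ignored on these versions:
println(Pre334MessageDemo.getTooLongSequenceErrorString(Int.MaxValue - 15, "sequence"))
// prints: Too long sequence found. Should be <= 2147483632
```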
diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index 68a6ce30569..dd387d453b5 100644
--- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}
 
-object RapidsErrorUtils extends RapidsQueryErrorUtils {
+object RapidsErrorUtils extends RapidsQueryErrorUtils with SequenceSizeTooLongErrorBuilder {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     // Follow the Spark string format before 3.3.0
diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index e5cdcd43568..80c61a9d481 100644
--- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -21,64 +21,9 @@
 {"spark": "332"}
 {"spark": "332cdh"}
 {"spark": "333"}
-{"spark": "334"}
 spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.shims
 
-import org.apache.spark.SparkDateTimeException
-import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}
+object RapidsErrorUtils extends RapidsErrorUtils330To334Base
+  with SequenceSizeTooLongErrorBuilder
 
-object RapidsErrorUtils extends RapidsErrorUtilsFor330plus with RapidsQueryErrorUtils {
-
-  def mapKeyNotExistError(
-      key: String,
-      keyType: DataType,
-      origin: Origin): NoSuchElementException = {
-    QueryExecutionErrors.mapKeyNotExistError(key, keyType, origin.context)
-  }
-
-  def invalidArrayIndexError(index: Int, numElements: Int,
-      isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
-    if (isElementAtF) {
-      QueryExecutionErrors.invalidElementAtIndexError(index, numElements)
-    } else {
-      QueryExecutionErrors.invalidArrayIndexError(index, numElements)
-    }
-  }
-
-  def arithmeticOverflowError(
-      message: String,
-      hint: String = "",
-      errorContext: String = ""): ArithmeticException = {
-    QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext)
-  }
-
-  def cannotChangeDecimalPrecisionError(
-      value: Decimal,
-      toType: DecimalType,
-      context: String = ""): ArithmeticException = {
-    QueryExecutionErrors.cannotChangeDecimalPrecisionError(
-      value, toType.precision, toType.scale, context
-    )
-  }
-
-  def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
-    QueryExecutionErrors.arithmeticOverflowError(
-      "Overflow in integral divide", "try_divide", context
-    )
-  }
-
-  def sparkDateTimeException(infOrNan: String): SparkDateTimeException = {
-    // These are the arguments required by SparkDateTimeException class to create error message.
-    val errorClass = "CAST_INVALID_INPUT"
-    val messageParameters = Array("DOUBLE", "TIMESTAMP", SQLConf.ANSI_ENABLED.key)
-    new SparkDateTimeException(errorClass, Array(infOrNan) ++ messageParameters)
-  }
-
-  def sqlArrayIndexNotStartAtOneError(): RuntimeException = {
-    new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
-  }
-}
diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils330To334Base.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils330To334Base.scala
new file mode 100644
index 00000000000..0e8f9261d6e
--- /dev/null
+++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils330To334Base.scala
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+{"spark": "330cdh"}
+{"spark": "331"}
+{"spark": "332"}
+{"spark": "332cdh"}
+{"spark": "333"}
+{"spark": "334"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.shims
+
+import org.apache.spark.SparkDateTimeException
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}
+
+trait RapidsErrorUtils330To334Base extends RapidsErrorUtilsFor330plus with RapidsQueryErrorUtils {
+
+  def mapKeyNotExistError(
+      key: String,
+      keyType: DataType,
+      origin: Origin): NoSuchElementException = {
+    QueryExecutionErrors.mapKeyNotExistError(key, keyType, origin.context)
+  }
+
+  def invalidArrayIndexError(index: Int, numElements: Int,
+      isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
+    if (isElementAtF) {
+      QueryExecutionErrors.invalidElementAtIndexError(index, numElements)
+    } else {
+      QueryExecutionErrors.invalidArrayIndexError(index, numElements)
+    }
+  }
+
+  def arithmeticOverflowError(
+      message: String,
+      hint: String = "",
+      errorContext: String = ""): ArithmeticException = {
+    QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext)
+  }
+
+  def cannotChangeDecimalPrecisionError(
+      value: Decimal,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    QueryExecutionErrors.cannotChangeDecimalPrecisionError(
+      value, toType.precision, toType.scale, context
+    )
+  }
+
+  def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
+    QueryExecutionErrors.arithmeticOverflowError(
+      "Overflow in integral divide", "try_divide", context
+    )
+  }
+
+  def sparkDateTimeException(infOrNan: String): SparkDateTimeException = {
+    // These are the arguments required by SparkDateTimeException class to create error message.
+    val errorClass = "CAST_INVALID_INPUT"
+    val messageParameters = Array("DOUBLE", "TIMESTAMP", SQLConf.ANSI_ENABLED.key)
+    new SparkDateTimeException(errorClass, Array(infOrNan) ++ messageParameters)
+  }
+
+  def sqlArrayIndexNotStartAtOneError(): RuntimeException = {
+    new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
+  }
+}
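The base trait above carries the version-independent helpers; each shim's `RapidsErrorUtils` object then just stacks it with the message builder matching its Spark release. A hypothetical composition, purely for illustration (the real objects are the per-version `RapidsErrorUtils` definitions in this patch):

```scala
// Hypothetical shim object showing the stacking pattern this refactor enables:
// shared error helpers come from the base, the sequence-size message is
// swapped per Spark version by mixing in a different builder trait.
object ExampleErrorUtils extends RapidsErrorUtils330To334Base
  with SequenceSizeTooLongErrorBuilder
```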
diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index 7e58a54c921..0f40c6e3bfd 100644
--- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -22,7 +22,9 @@ package org.apache.spark.sql.rapids.shims
 
 import org.apache.spark.sql.errors.QueryExecutionErrors
 
-object RapidsErrorUtils extends RapidsErrorUtilsBase with RapidsQueryErrorUtils {
+object RapidsErrorUtils extends RapidsErrorUtilsBase
+  with RapidsQueryErrorUtils
+  with SequenceSizeTooLongErrorBuilder {
   def sqlArrayIndexNotStartAtOneError(): RuntimeException = {
     QueryExecutionErrors.elementAtByIndexZeroError(context = null)
   }
diff --git a/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala b/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
index aba0f465483..f386973200a 100644
--- a/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
+++ b/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
@@ -31,8 +31,6 @@ import org.apache.spark.sql.rapids.{AddOverflowChecks, SubtractOverflowChecks}
 import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 
 object GetSequenceSize {
-  val TOO_LONG_SEQUENCE = "Unsuccessful try to create array with elements exceeding the array " +
-    s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
   /**
    * Compute the size of each sequence according to 'start', 'stop' and 'step'.
    * A row (Row[start, stop, step]) contains at least one null element will produce
diff --git a/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
new file mode 100644
index 00000000000..0376d2e69a7
--- /dev/null
+++ b/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "334"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.shims
+
+object RapidsErrorUtils extends RapidsErrorUtils330To334Base
+  with SequenceSizeTooLongUnsuccessfulErrorBuilder
+
diff --git a/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/SequenceSizeTooLongUnsuccessfulErrorBuilder.scala b/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/SequenceSizeTooLongUnsuccessfulErrorBuilder.scala
new file mode 100644
index 00000000000..5e584de7167
--- /dev/null
+++ b/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/SequenceSizeTooLongUnsuccessfulErrorBuilder.scala
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "334"}
+{"spark": "342"}
+{"spark": "343"}
+{"spark": "351"}
+{"spark": "352"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.shims
+
+import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+trait SequenceSizeTooLongUnsuccessfulErrorBuilder {
+  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
+    // The errant function's name does not feature in the exception message
+    // prior to Spark 4.0. Neither does the attempted allocation size.
+    "Unsuccessful try to create array with elements exceeding the array " +
+      s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
+  }
+}
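And the counterpart check for the 3.3.4 / 3.4.2+ wording, again with a hypothetical demo object (this builder, too, ignores both arguments):

```scala
// Hypothetical demo object; not part of the PR.
object UnsuccessfulMessageDemo extends SequenceSizeTooLongUnsuccessfulErrorBuilder

println(UnsuccessfulMessageDemo.getTooLongSequenceErrorString(Int.MaxValue - 15, "sequence"))
// prints: Unsuccessful try to create array with elements exceeding the array size limit 2147483632
```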
" - ) - } - - def invalidArrayIndexError( - index: Int, - numElements: Int, - isElementAtF: Boolean = false, - context: SQLQueryContext = null): ArrayIndexOutOfBoundsException = { - if (isElementAtF) { - QueryExecutionErrors.invalidElementAtIndexError(index, numElements, context) - } else { - QueryExecutionErrors.invalidArrayIndexError(index, numElements, context) - } - } - - def arithmeticOverflowError( - message: String, - hint: String = "", - errorContext: SQLQueryContext = null): ArithmeticException = { - QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext) - } - - def cannotChangeDecimalPrecisionError( - value: Decimal, - toType: DecimalType, - context: SQLQueryContext = null): ArithmeticException = { - QueryExecutionErrors.cannotChangeDecimalPrecisionError( - value, toType.precision, toType.scale, context - ) - } - - def overflowInIntegralDivideError(context: SQLQueryContext = null): ArithmeticException = { - QueryExecutionErrors.arithmeticOverflowError( - "Overflow in integral divide", "try_divide", context - ) - } - - def sparkDateTimeException(infOrNan: String): SparkDateTimeException = { - // These are the arguments required by SparkDateTimeException class to create error message. - val errorClass = "CAST_INVALID_INPUT" - val messageParameters = Map("expression" -> infOrNan, "sourceType" -> "DOUBLE", - "targetType" -> "TIMESTAMP", "ansiConfig" -> SQLConf.ANSI_ENABLED.key) - SparkDateTimeExceptionShims.newSparkDateTimeException(errorClass, messageParameters, - Array.empty, "") - } - - def sqlArrayIndexNotStartAtOneError(): RuntimeException = { - QueryExecutionErrors.invalidIndexOfZeroError(context = null) - } - - override def intervalDivByZeroError(origin: Origin): ArithmeticException = { - QueryExecutionErrors.intervalDividedByZeroError(origin.context) - } -} +object RapidsErrorUtils extends RapidsErrorUtils340PlusBase + with SequenceSizeTooLongErrorBuilder diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils340PlusBase.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils340PlusBase.scala new file mode 100644 index 00000000000..173b06e3f8f --- /dev/null +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils340PlusBase.scala @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "340"} +{"spark": "341"} +{"spark": "342"} +{"spark": "343"} +{"spark": "350"} +{"spark": "351"} +{"spark": "352"} +{"spark": "400"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.shims + +import org.apache.spark.SparkDateTimeException +import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} + +trait RapidsErrorUtils340PlusBase extends RapidsErrorUtilsFor330plus with RapidsQueryErrorUtils { + + def mapKeyNotExistError( + key: String, + keyType: DataType, + origin: Origin): NoSuchElementException = { + throw new UnsupportedOperationException( + "`mapKeyNotExistError` has been removed since Spark 3.4.0. " + ) + } + + def invalidArrayIndexError( + index: Int, + numElements: Int, + isElementAtF: Boolean = false, + context: SQLQueryContext = null): ArrayIndexOutOfBoundsException = { + if (isElementAtF) { + QueryExecutionErrors.invalidElementAtIndexError(index, numElements, context) + } else { + QueryExecutionErrors.invalidArrayIndexError(index, numElements, context) + } + } + + def arithmeticOverflowError( + message: String, + hint: String = "", + errorContext: SQLQueryContext = null): ArithmeticException = { + QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext) + } + + def cannotChangeDecimalPrecisionError( + value: Decimal, + toType: DecimalType, + context: SQLQueryContext = null): ArithmeticException = { + QueryExecutionErrors.cannotChangeDecimalPrecisionError( + value, toType.precision, toType.scale, context + ) + } + + def overflowInIntegralDivideError(context: SQLQueryContext = null): ArithmeticException = { + QueryExecutionErrors.arithmeticOverflowError( + "Overflow in integral divide", "try_divide", context + ) + } + + def sparkDateTimeException(infOrNan: String): SparkDateTimeException = { + // These are the arguments required by SparkDateTimeException class to create error message. 
+ val errorClass = "CAST_INVALID_INPUT" + val messageParameters = Map("expression" -> infOrNan, "sourceType" -> "DOUBLE", + "targetType" -> "TIMESTAMP", "ansiConfig" -> SQLConf.ANSI_ENABLED.key) + SparkDateTimeExceptionShims.newSparkDateTimeException(errorClass, messageParameters, + Array.empty, "") + } + + def sqlArrayIndexNotStartAtOneError(): RuntimeException = { + QueryExecutionErrors.invalidIndexOfZeroError(context = null) + } + + override def intervalDivByZeroError(origin: Origin): ArithmeticException = { + QueryExecutionErrors.intervalDividedByZeroError(origin.context) + } +} diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 9b800d4e51a..37393604f42 100644 --- a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -21,7 +21,9 @@ package org.apache.spark.sql.rapids.shims import org.apache.spark.sql.errors.QueryExecutionErrors -object RapidsErrorUtils extends RapidsErrorUtilsBase with RapidsQueryErrorUtils { +object RapidsErrorUtils extends RapidsErrorUtilsBase + with RapidsQueryErrorUtils + with SequenceSizeTooLongErrorBuilder { def sqlArrayIndexNotStartAtOneError(): RuntimeException = { QueryExecutionErrors.invalidIndexOfZeroError(context = null) } diff --git a/sql-plugin/src/main/spark342/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark342/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala new file mode 100644 index 00000000000..b07ea3b1c7e --- /dev/null +++ b/sql-plugin/src/main/spark342/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "342"} +{"spark": "343"} +{"spark": "351"} +{"spark": "352"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.shims + +object RapidsErrorUtils extends RapidsErrorUtils340PlusBase + with SequenceSizeTooLongUnsuccessfulErrorBuilder diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala new file mode 100644 index 00000000000..a7eca011383 --- /dev/null +++ b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
new file mode 100644
index 00000000000..a7eca011383
--- /dev/null
+++ b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "400"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.shims
+
+object RapidsErrorUtils extends RapidsErrorUtils340PlusBase
+  with SequenceSizeExceededLimitErrorBuilder
diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeExceededLimitErrorBuilder.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeExceededLimitErrorBuilder.scala
new file mode 100644
index 00000000000..741634aea3f
--- /dev/null
+++ b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeExceededLimitErrorBuilder.scala
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "400"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.shims
+
+import org.apache.spark.sql.errors.QueryExecutionErrors
+
+trait SequenceSizeExceededLimitErrorBuilder {
+  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
+    QueryExecutionErrors.createArrayWithElementsExceedLimitError(functionName, sequenceSize)
+      .getMessage
+  }
+}
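On Spark 4.0 the plugin stops hardcoding the text altogether and asks `QueryExecutionErrors` to render it, so the message automatically tracks Spark's error-class templates. A hedged usage sketch — the exact wording is owned by Spark, and the updated integration test above only asserts the "Can't create array" prefix:

```scala
// Hedged sketch: requires a Spark 4.0 classpath. The plugin only forwards the
// function name and the attempted size; Spark renders the message.
val msg = RapidsErrorUtils.getTooLongSequenceErrorString(
  sequenceSize = Int.MaxValue - 15,
  functionName = "sequence")
assert(msg.startsWith("Can't create array"))  // matches the updated test's expectation
```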