diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 55b883a479b..1cd916cbac1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3797,14 +3797,6 @@ object GpuOverrides extends Logging { TypeSig.ARRAY.nested(TypeSig.all)), (e, conf, p, r) => new GpuGetArrayStructFieldsMeta(e, conf, p, r) ), - expr[RaiseError]( - "Throw an exception", - ExprChecks.unaryProject( - TypeSig.NULL, TypeSig.NULL, - TypeSig.STRING, TypeSig.STRING), - (a, conf, p, r) => new UnaryExprMeta[RaiseError](a, conf, p, r) { - override def convertToGpu(child: Expression): GpuExpression = GpuRaiseError(child) - }), expr[DynamicPruningExpression]( "Dynamic pruning expression marker", ExprChecks.unaryProject(TypeSig.all, TypeSig.all, TypeSig.BOOLEAN, TypeSig.BOOLEAN), @@ -3820,7 +3812,8 @@ object GpuOverrides extends Logging { val expressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = commonExpressions ++ TimeStamp.getExprs ++ GpuHiveOverrides.exprs ++ ZOrderRules.exprs ++ DecimalArithmeticOverrides.exprs ++ - BloomFilterShims.exprs ++ InSubqueryShims.exprs ++ SparkShimImpl.getExprs + BloomFilterShims.exprs ++ InSubqueryShims.exprs ++ RaiseErrorShim.getExprs ++ + SparkShimImpl.getExprs def wrapScan[INPUT <: Scan]( scan: INPUT, diff --git a/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala new file mode 100644 index 00000000000..de433d5f270 --- /dev/null +++ b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/*** spark-rapids-shim-json-lines +{"spark": "311"} +{"spark": "312"} +{"spark": "313"} +{"spark": "320"} +{"spark": "321"} +{"spark": "321cdh"} +{"spark": "322"} +{"spark": "323"} +{"spark": "324"} +{"spark": "330"} +{"spark": "330cdh"} +{"spark": "330db"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332cdh"} +{"spark": "332db"} +{"spark": "333"} +{"spark": "334"} +{"spark": "340"} +{"spark": "341"} +{"spark": "341db"} +{"spark": "342"} +{"spark": "343"} +{"spark": "350"} +{"spark": "351"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import com.nvidia.spark.rapids.{ExprRule, GpuOverrides} +import com.nvidia.spark.rapids.{ExprChecks, GpuExpression, TypeSig, UnaryExprMeta} + +import org.apache.spark.sql.catalyst.expressions.{Expression, RaiseError} +import org.apache.spark.sql.rapids.shims.GpuRaiseError + +object RaiseErrorShim { + val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { + Seq(GpuOverrides.expr[RaiseError]( + "Throw an exception", + ExprChecks.unaryProject( + TypeSig.NULL, TypeSig.NULL, + TypeSig.STRING, TypeSig.STRING), + (a, conf, p, r) => new UnaryExprMeta[RaiseError](a, conf, p, r) { + override def convertToGpu(child: Expression): GpuExpression = GpuRaiseError(child) + })).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/misc.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/misc.scala similarity index 75% rename from sql-plugin/src/main/scala/org/apache/spark/sql/rapids/misc.scala rename to sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/misc.scala index b32bdfa207c..1ab58ddcbb6 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/misc.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/misc.scala @@ -13,10 +13,36 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/*** spark-rapids-shim-json-lines +{"spark": "311"} +{"spark": "312"} +{"spark": "313"} +{"spark": "320"} +{"spark": "321"} +{"spark": "321cdh"} +{"spark": "322"} +{"spark": "323"} +{"spark": "324"} +{"spark": "330"} +{"spark": "330cdh"} +{"spark": "330db"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332cdh"} +{"spark": "332db"} +{"spark": "333"} +{"spark": "334"} +{"spark": "340"} +{"spark": "341"} +{"spark": "341db"} +{"spark": "342"} +{"spark": "343"} +{"spark": "350"} +{"spark": "351"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.shims -package org.apache.spark.sql.rapids - -import ai.rapids.cudf.{ColumnVector} +import ai.rapids.cudf.ColumnVector import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression} import com.nvidia.spark.rapids.Arm.withResource diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala new file mode 100644 index 00000000000..af199301e47 --- /dev/null +++ b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/*** spark-rapids-shim-json-lines +{"spark": "400"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import com.nvidia.spark.rapids.{ExprRule, GpuOverrides} +import com.nvidia.spark.rapids.{ExprChecks, GpuExpression, TypeSig, UnaryExprMeta} + +import org.apache.spark.sql.catalyst.expressions.{Expression, RaiseError} +import org.apache.spark.sql.rapids.shims.GpuRaiseError + +object RaiseErrorShim { + val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { + Seq(GpuOverrides.expr[RaiseError]( + "Throw an exception", + ExprChecks.binaryProject( + TypeSig.NULL, TypeSig.NULL, + ("lhs", TypeSig.STRING, TypeSig.MAP.nested(TypeSig.STRING)), + ("rhs", TypeSig.STRING, TypeSig.MAP.nested(TypeSig.STRING))), + (a, conf, p, r) => new UnaryExprMeta[RaiseError](a, conf, p, r) { + override def convertToGpu(child: Expression): GpuExpression = GpuRaiseError(child) + })).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap + } +} diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/misc.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/misc.scala new file mode 100644 index 00000000000..8d654852579 --- /dev/null +++ b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/misc.scala @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/*** spark-rapids-shim-json-lines +{"spark": "400"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.shims + +import ai.rapids.cudf.ColumnVector +import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression} +import com.nvidia.spark.rapids.Arm.withResource + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression} +import org.apache.spark.sql.types.{AbstractDataType, DataType, NullType, StringType} +import org.apache.spark.sql.internal.types.StringTypeAnyCollation + +case class GpuRaiseError( + errorClass: Expression, + errorParams: Expression) extends GpuBinaryExpression with ImplicitCastInputTypes { + + def this(str: Expression) = { + this(Literal( + if (SQLConf.get.getConf(SQLConf.LEGACY_RAISE_ERROR_WITHOUT_ERROR_CLASS)) { + "_LEGACY_ERROR_USER_RAISED_EXCEPTION" + } else { + "USER_RAISED_EXCEPTION" + }), + CreateMap(Seq(Literal("errorMessage"), str))) + } + + def this(errorClass: Expression, errorParms: Expression) = { + this(errorClass, errorParms) + } + + override def dataType: DataType = NullType + + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, MapType(StringType, StringType)) + + /** Could evaluating this expression cause side-effects, such as throwing an exception? */ + override def hasSideEffects: Boolean = true + + override protected def doColumnar( + lhs: GpuColumnVector, + rhs: GpuColumnVector): ColumnVector = { + if (lhs.getRowCount <= 0) { + // For the case: when(condition, raise_error(col("a")) + return GpuColumnVector.columnVectorFromNull(0, NullType) + } + + // Take the first one as the error message + withResource(lhs.getBase.getScalarElement(0)) { errorClass => + if (!errorClass.isValid()) { + throw new RuntimeException() + } else { + withResource(rhs.getBase.getScalarElement(0)) { errorParams => + if (!errorParams.isValid()) { + throw new RuntimeException() + } else { + if (errorClass.getJavaString.equals("USER_RAISED_EXCEPTION") || + errorClass.getJavaString.equals("_LEGACY_ERROR_USER_RAISED_EXCEPTION")) { + val strMessage = GpuGetMapValue(errorParams, "errorMessage", false) + throw new RapidsAnalysisException(strMessage) + } else { + RapidsSparkThrowableHelper() + } + } + } + } + } + + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, MapType(StringType, StringType)) + + override def left: Expression = errorClass + + override def right: Expression = errorParms + + override def prettyName: String = "raise_error" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): RaiseError = { + copy(errorClass = newLeft, errorParms = newRight) + } +}