diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 8b266e9d6ac11..5037b52475422 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3126,6 +3126,13 @@ ], "sqlState" : "42K08" }, + "INVALID_SQL_FUNCTION_PLAN_STRUCTURE" : { + "message" : [ + "Invalid SQL function plan structure", + "" + ], + "sqlState" : "XXKD0" + }, "INVALID_SQL_SYNTAX" : { "message" : [ "Invalid SQL syntax:" @@ -5757,6 +5764,12 @@ ], "sqlState" : "0A000" }, + "UNSUPPORTED_SQL_UDF_USAGE" : { + "message" : [ + "Using SQL function in is not supported." + ], + "sqlState" : "0A000" + }, "UNSUPPORTED_STREAMING_OPERATOR_WITHOUT_WATERMARK" : { "message" : [ " output mode not supported for on streaming DataFrames/DataSets without watermark." diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java index 4200619d3c5f9..310d18ddb3486 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java @@ -51,7 +51,7 @@ public class ExpressionInfo { "window_funcs", "xml_funcs", "table_funcs", "url_funcs", "variant_funcs")); private static final Set validSources = - new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf", + new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf", "sql_udf", "java_udf", "python_udtf", "internal")); public String getClassName() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9282e0554a2d4..92cfc4119dd0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -374,6 +374,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor BindProcedures :: ResolveTableSpec :: ValidateAndStripPipeExpressions :: + ResolveSQLFunctions :: ResolveAliases :: ResolveSubquery :: ResolveSubqueryColumnAliases :: @@ -2364,6 +2365,277 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } + /** + * This rule resolves SQL function expressions. It pulls out function inputs and place them + * in a separate [[Project]] node below the operator and replace the SQL function with its + * actual function body. SQL function expressions in [[Aggregate]] are handled in a special + * way. Non-aggregated SQL functions in the aggregate expressions of an Aggregate need to be + * pulled out into a Project above the Aggregate before replacing the SQL function expressions + * with actual function bodies. For example: + * + * Before: + * Aggregate [c1] [foo(c1), foo(max(c2)), sum(foo(c2)) AS sum] + * +- Relation [c1, c2] + * + * After: + * Project [foo(c1), foo(max_c2), sum] + * +- Aggregate [c1] [c1, max(c2) AS max_c2, sum(foo(c2)) AS sum] + * +- Relation [c1, c2] + */ + object ResolveSQLFunctions extends Rule[LogicalPlan] { + + private def hasSQLFunctionExpression(exprs: Seq[Expression]): Boolean = { + exprs.exists(_.find(_.isInstanceOf[SQLFunctionExpression]).nonEmpty) + } + + /** + * Check if the function input contains aggregate expressions. 
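The Before/After plans in the ResolveSQLFunctions comment above can be reproduced end to end with a scalar SQL UDF used on a grouping column, around an aggregate, and inside an aggregate. A minimal Scala sketch, assuming a SparkSession `spark` and a table `t(c1, c2)` that are not part of this patch:

// Mirrors the rule's doc example: foo(c1) and foo(MAX(c2)) are pulled into a
// Project above the Aggregate, while SUM(foo(c2)) stays in the aggregate list.
spark.sql("CREATE FUNCTION foo(a INT) RETURNS INT RETURN a + 1")
spark.sql(
  """SELECT foo(c1), foo(MAX(c2)), SUM(foo(c2))
    |FROM t
    |GROUP BY c1""".stripMargin).explain(true)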
+ */ + private def checkFunctionInput(f: SQLFunctionExpression): Unit = { + if (f.inputs.exists(AggregateExpression.containsAggregate)) { + // The input of a SQL function should not contain aggregate functions after + // `extractAndRewrite`. If there are aggregate functions, it means they are + // nested in another aggregate function, which is not allowed. + // For example: SELECT sum(foo(sum(c1))) FROM t + // We have to throw the error here because otherwise the query plan after + // resolving the SQL function will not be valid. + throw new AnalysisException( + errorClass = "NESTED_AGGREGATE_FUNCTION", + messageParameters = Map.empty) + } + } + + /** + * Resolve a SQL function expression as a logical plan check if it can be analyzed. + */ + private def resolve(f: SQLFunctionExpression): LogicalPlan = { + // Validate the SQL function input. + checkFunctionInput(f) + val plan = v1SessionCatalog.makeSQLFunctionPlan(f.name, f.function, f.inputs) + val resolved = SQLFunctionContext.withSQLFunction { + // Resolve the SQL function plan using its context. + val conf = new SQLConf() + f.function.getSQLConfigs.foreach { case (k, v) => conf.settings.put(k, v) } + SQLConf.withExistingConf(conf) { + executeSameContext(plan) + } + } + // Fail the analysis eagerly if a SQL function cannot be resolved using its input. + SimpleAnalyzer.checkAnalysis(resolved) + resolved + } + + /** + * Rewrite SQL function expressions into actual resolved function bodies and extract + * function inputs into the given project list. + */ + private def rewriteSQLFunctions[E <: Expression]( + expression: E, + projectList: ArrayBuffer[NamedExpression]): E = { + val newExpr = expression match { + case f: SQLFunctionExpression if !hasSQLFunctionExpression(f.inputs) && + // Make sure LateralColumnAliasReference in parameters is resolved and eliminated first. + // Otherwise, the projectList can contain the LateralColumnAliasReference, which will be + // pushed down to a Project without the 'referenced' alias by LCA present, leaving it + // unresolved. + !f.inputs.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + withPosition(f) { + val plan = resolve(f) + // Extract the function input project list from the SQL function plan and + // inline the SQL function expression. + plan match { + case Project(body :: Nil, Project(aliases, _: LocalRelation)) => + projectList ++= aliases + SQLScalarFunction(f.function, aliases.map(_.toAttribute), body) + case o => + throw new AnalysisException( + errorClass = "INVALID_SQL_FUNCTION_PLAN_STRUCTURE", + messageParameters = Map("plan" -> o.toString)) + } + } + case o => o.mapChildren(rewriteSQLFunctions(_, projectList)) + } + newExpr.asInstanceOf[E] + } + + /** + * Check if the given expression contains expressions that should be extracted, + * i.e. non-aggregated SQL functions with non-foldable inputs. + */ + private def shouldExtract(e: Expression): Boolean = e match { + // Return false if the expression is already an aggregate expression. + case _: AggregateExpression => false + case _: SQLFunctionExpression => true + case _: LeafExpression => false + case o => o.children.exists(shouldExtract) + } + + /** + * Extract aggregate expressions from the given expression and replace + * them with attribute references. 
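A minimal sketch of the nested-aggregate restriction enforced by checkFunctionInput above; the SparkSession `spark`, the function `foo`, and the table `t(c1 INT)` are illustrative only:

spark.sql("CREATE FUNCTION foo(a INT) RETURNS INT RETURN a + 1")
// Rejected during analysis with NESTED_AGGREGATE_FUNCTION: the UDF input
// SUM(c1) would remain nested inside the outer SUM after extractAndRewrite.
spark.sql("SELECT SUM(foo(SUM(c1))) FROM t")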
+ * Example: + * Before: foo(c1) + foo(max(c2)) + max(foo(c2)) + * After: foo(c1) + foo(max_c2) + max_foo_c2 + * Extracted expressions: [c1, max(c2) AS max_c2, max(foo(c2)) AS max_foo_c2] + */ + private def extractAndRewrite[T <: Expression]( + expression: T, + extractedExprs: ArrayBuffer[NamedExpression]): T = { + val newExpr = expression match { + case e if !shouldExtract(e) => + val exprToAdd: NamedExpression = e match { + case o: OuterReference => Alias(o, toPrettySQL(o.e))() + case ne: NamedExpression => ne + case o => Alias(o, toPrettySQL(o))() + } + extractedExprs += exprToAdd + exprToAdd.toAttribute + case f: SQLFunctionExpression => + val newInputs = f.inputs.map(extractAndRewrite(_, extractedExprs)) + f.copy(inputs = newInputs) + case o => o.mapChildren(extractAndRewrite(_, extractedExprs)) + } + newExpr.asInstanceOf[T] + } + + /** + * Replace all [[SQLFunctionExpression]]s in an expression with attribute references + * from the aliasMap. + */ + private def replaceSQLFunctionWithAttr[T <: Expression]( + expr: T, + aliasMap: mutable.HashMap[Expression, Alias]): T = { + expr.transform { + case f: SQLFunctionExpression if aliasMap.contains(f.canonicalized) => + aliasMap(f.canonicalized).toAttribute + }.asInstanceOf[T] + } + + private def rewrite(plan: LogicalPlan): LogicalPlan = plan match { + // Return if a sub-tree does not contain SQLFunctionExpression. + case p: LogicalPlan if !p.containsPattern(SQL_FUNCTION_EXPRESSION) => p + + case f @ Filter(cond, a: Aggregate) + if !f.resolved || AggregateExpression.containsAggregate(cond) || + ResolveGroupingAnalytics.hasGroupingFunction(cond) || + cond.containsPattern(TEMP_RESOLVED_COLUMN) => + // If the filter's condition contains aggregate expressions or grouping expressions or temp + // resolved column, we cannot rewrite both the filter and the aggregate until they are + // resolved by ResolveAggregateFunctions or ResolveGroupingAnalytics, because rewriting SQL + // functions in aggregate can add an additional project on top of the aggregate + // which breaks the pattern matching in those rules. + f.copy(child = a.copy(child = rewrite(a.child))) + + case h @ UnresolvedHaving(_, a: Aggregate) => + // Similarly UnresolvedHaving should be resolved by ResolveAggregateFunctions first + // before rewriting aggregate. + h.copy(child = a.copy(child = rewrite(a.child))) + + case a: Aggregate if a.resolved && hasSQLFunctionExpression(a.expressions) => + val child = rewrite(a.child) + // Extract SQL functions in the grouping expressions and place them in a project list + // below the current aggregate. Also update their appearances in the aggregate expressions. + val bottomProjectList = ArrayBuffer.empty[NamedExpression] + val aliasMap = mutable.HashMap.empty[Expression, Alias] + val newGrouping = a.groupingExpressions.map { expr => + expr.transformDown { + case f: SQLFunctionExpression => + val alias = aliasMap.getOrElseUpdate(f.canonicalized, Alias(f, f.name)()) + bottomProjectList += alias + alias.toAttribute + } + } + val aggregateExpressions = a.aggregateExpressions.map( + replaceSQLFunctionWithAttr(_, aliasMap)) + + // Rewrite SQL functions in the aggregate expressions that are not wrapped in + // aggregate functions. They need to be extracted into a project list above the + // current aggregate. + val aggExprs = ArrayBuffer.empty[NamedExpression] + val topProjectList = aggregateExpressions.map(extractAndRewrite(_, aggExprs)) + + // Rewrite SQL functions in the new aggregate expressions that are wrapped inside + // aggregate functions. 
+ val newAggExprs = aggExprs.map(rewriteSQLFunctions(_, bottomProjectList)) + + val bottomProject = if (bottomProjectList.nonEmpty) { + Project(child.output ++ bottomProjectList, child) + } else { + child + } + val newAgg = if (newGrouping.nonEmpty || newAggExprs.nonEmpty) { + a.copy( + groupingExpressions = newGrouping, + aggregateExpressions = newAggExprs.toSeq, + child = bottomProject) + } else { + bottomProject + } + if (topProjectList.nonEmpty) Project(topProjectList, newAgg) else newAgg + + case p: Project if p.resolved && hasSQLFunctionExpression(p.expressions) => + val newChild = rewrite(p.child) + val projectList = ArrayBuffer.empty[NamedExpression] + val newPList = p.projectList.map(rewriteSQLFunctions(_, projectList)) + if (newPList != newChild.output) { + p.copy(newPList, Project(newChild.output ++ projectList, newChild)) + } else { + assert(projectList.isEmpty) + p.copy(child = newChild) + } + + case f: Filter if f.resolved && hasSQLFunctionExpression(f.expressions) => + val newChild = rewrite(f.child) + val projectList = ArrayBuffer.empty[NamedExpression] + val newCond = rewriteSQLFunctions(f.condition, projectList) + if (newCond != f.condition) { + Project(f.output, Filter(newCond, Project(newChild.output ++ projectList, newChild))) + } else { + assert(projectList.isEmpty) + f.copy(child = newChild) + } + + case j: Join if j.resolved && hasSQLFunctionExpression(j.expressions) => + val newLeft = rewrite(j.left) + val newRight = rewrite(j.right) + val projectList = ArrayBuffer.empty[NamedExpression] + val joinCond = j.condition.map(rewriteSQLFunctions(_, projectList)) + if (joinCond != j.condition) { + // Join condition cannot have non-deterministic expressions. We can safely + // replace the aliases with the original SQL function input expressions. + val aliasMap = projectList.collect { case a: Alias => a.toAttribute -> a.child }.toMap + val newJoinCond = joinCond.map(_.transform { + case a: Attribute => aliasMap.getOrElse(a, a) + }) + j.copy(left = newLeft, right = newRight, condition = newJoinCond) + } else { + assert(projectList.isEmpty) + j.copy(left = newLeft, right = newRight) + } + + case o: LogicalPlan if o.resolved && hasSQLFunctionExpression(o.expressions) => + o.transformExpressionsWithPruning(_.containsPattern(SQL_FUNCTION_EXPRESSION)) { + case f: SQLFunctionExpression => + f.failAnalysis( + errorClass = "UNSUPPORTED_SQL_UDF_USAGE", + messageParameters = Map( + "functionName" -> toSQLId(f.function.name.nameParts), + "nodeName" -> o.nodeName.toString)) + } + + case p: LogicalPlan => p.mapChildren(rewrite) + } + + def apply(plan: LogicalPlan): LogicalPlan = { + // Only rewrite SQL functions when they are not in nested function calls. + if (SQLFunctionContext.get.nestedSQLFunctionDepth > 0) { + plan + } else { + rewrite(plan) + } + } + } + /** * Turns projections that contain aggregate expressions into aggregations. 
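The Filter branch above places the extracted UDF inputs in a Project below the Filter and restores the original output above it. A sketch reusing the view V1 and the function foo2_3 defined in the sql-udf golden file:

// Analyzed roughly as (see sql-udf.sql.out):
//   Project [c1, c2]
//   +- Filter ((foo2_3(a, b) = c1) AND (foo2_3(a, b) < 8))
//      +- Project [c1, c2, CAST(c1 AS INT) AS a, CAST(0 AS INT) AS b, ...]
//         +- ...V1
spark.sql("SELECT * FROM V1 WHERE foo2_3(c1, 0) = c1 AND foo2_3(c1, c2) < 8").explain(true)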
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 46ca8e793218b..0a68524c31241 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1106,6 +1106,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB @scala.annotation.tailrec def cleanQueryInScalarSubquery(p: LogicalPlan): LogicalPlan = p match { case s: SubqueryAlias => cleanQueryInScalarSubquery(s.child) + // Skip SQL function node added by the Analyzer + case s: SQLFunctionNode => cleanQueryInScalarSubquery(s.child) case p: Project => cleanQueryInScalarSubquery(p.child) case h: ResolvedHint => cleanQueryInScalarSubquery(h.child) case child => child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala index fb6935d64d4c4..37981f47287da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.catalog.SQLFunction -import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable} -import org.apache.spark.sql.catalyst.trees.TreePattern.{SQL_FUNCTION_EXPRESSION, TreePattern} +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression, Unevaluable} +import org.apache.spark.sql.catalyst.trees.TreePattern.{SQL_FUNCTION_EXPRESSION, SQL_SCALAR_FUNCTION, TreePattern} import org.apache.spark.sql.types.DataType /** @@ -39,3 +39,52 @@ case class SQLFunctionExpression( newChildren: IndexedSeq[Expression]): SQLFunctionExpression = copy(inputs = newChildren) final override val nodePatterns: Seq[TreePattern] = Seq(SQL_FUNCTION_EXPRESSION) } + +/** + * A wrapper node for a SQL scalar function expression. + */ +case class SQLScalarFunction(function: SQLFunction, inputs: Seq[Expression], child: Expression) + extends UnaryExpression with Unevaluable { + override def dataType: DataType = child.dataType + override def toString: String = s"${function.name}(${inputs.mkString(", ")})" + override def sql: String = s"${function.name}(${inputs.map(_.sql).mkString(", ")})" + override protected def withNewChildInternal(newChild: Expression): SQLScalarFunction = { + copy(child = newChild) + } + final override val nodePatterns: Seq[TreePattern] = Seq(SQL_SCALAR_FUNCTION) + // The `inputs` is for display only and does not matter in execution. + override lazy val canonicalized: Expression = copy(inputs = Nil, child = child.canonicalized) + override lazy val deterministic: Boolean = { + function.deterministic.getOrElse(true) && children.forall(_.deterministic) + } +} + +/** + * Provide a way to keep state during analysis for resolving nested SQL functions. + * + * @param nestedSQLFunctionDepth The nested depth in the SQL function resolution. A SQL function + * expression should only be expanded as a [[SQLScalarFunction]] if + * the nested depth is 0. 
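A sketch of how the SQLScalarFunction wrapper surfaces in an analyzed plan; foo2_1b is the function from the sql-udf golden file and the assertion is illustrative, not part of this patch:

import org.apache.spark.sql.catalyst.analysis.SQLScalarFunction
// The analyzed plan keeps the wrapper, rendered as the original call
// (e.g. spark_catalog.default.foo2_1b(a, b)); only `child` is evaluated later.
val analyzed = spark.sql("SELECT foo2_1b(5, 6)").queryExecution.analyzed
assert(analyzed.expressions.exists(_.exists(_.isInstanceOf[SQLScalarFunction])))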
+ */ +case class SQLFunctionContext(nestedSQLFunctionDepth: Int = 0) + +object SQLFunctionContext { + + private val value = new ThreadLocal[SQLFunctionContext]() { + override def initialValue: SQLFunctionContext = SQLFunctionContext() + } + + def get: SQLFunctionContext = value.get() + + def reset(): Unit = value.remove() + + private def set(context: SQLFunctionContext): Unit = value.set(context) + + def withSQLFunction[A](f: => A): A = { + val originContext = value.get() + val context = originContext.copy( + nestedSQLFunctionDepth = originContext.nestedSQLFunctionDepth + 1) + set(context) + try f finally { set(originContext) } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 3c6dfe5ac8445..b123952c5f086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -38,9 +38,9 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder import org.apache.spark.sql.catalyst.catalog.SQLFunction.parseDefault -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, ExpressionInfo, NamedArgumentExpression, NamedExpression, UpCast} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, ExpressionInfo, NamedArgumentExpression, NamedExpression, ScalarSubquery, UpCast} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} -import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter, LogicalPlan, NamedParametersSupport, Project, SubqueryAlias, View} +import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter, LocalRelation, LogicalPlan, NamedParametersSupport, Project, SubqueryAlias, View} import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils} import org.apache.spark.sql.connector.catalog.CatalogManager @@ -48,7 +48,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} import org.apache.spark.sql.util.{CaseInsensitiveStringMap, PartitioningUtils} import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -1561,6 +1561,103 @@ class SessionCatalog( } } + /** + * Constructs a scalar SQL function logical plan. The logical plan will be used to + * construct actual expression from the function inputs and body. + * + * The body of a scalar SQL function can either be an expression or a query returns + * one single column. 
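A minimal sketch of the thread-local depth tracking provided by withSQLFunction and consulted by ResolveSQLFunctions.apply; the assertions simply restate what the code above implies:

import org.apache.spark.sql.catalyst.analysis.SQLFunctionContext
assert(SQLFunctionContext.get.nestedSQLFunctionDepth == 0)
SQLFunctionContext.withSQLFunction {
  // While a function body is being resolved the depth is incremented, so
  // ResolveSQLFunctions leaves nested SQLFunctionExpressions for a later pass.
  assert(SQLFunctionContext.get.nestedSQLFunctionDepth == 1)
}
// The original context is restored afterwards.
assert(SQLFunctionContext.get.nestedSQLFunctionDepth == 0)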
+ * + * Example scalar SQL function with an expression: + * + * CREATE FUNCTION area(width DOUBLE, height DOUBLE) RETURNS DOUBLE + * RETURN width * height; + * + * Query: + * + * SELECT area(a, b) FROM t; + * + * SQL function plan: + * + * Project [CAST(width * height AS DOUBLE) AS area] + * +- Project [CAST(a AS DOUBLE) AS width, CAST(b AS DOUBLE) AS height] + * +- LocalRelation [a, b] + * + * Example scalar SQL function with a subquery: + * + * CREATE FUNCTION foo(x INT) RETURNS INT + * RETURN SELECT SUM(b) FROM t WHERE x = a; + * + * SELECT foo(a) FROM t; + * + * SQL function plan: + * + * Project [scalar-subquery AS foo] + * : +- Aggregate [] [sum(b)] + * : +- Filter [outer(x) = a] + * : +- Relation [a, b] + * +- Project [CAST(a AS INT) AS x] + * +- LocalRelation [a, b] + */ + def makeSQLFunctionPlan( + name: String, + function: SQLFunction, + input: Seq[Expression]): LogicalPlan = { + def metaForFuncInputAlias = { + new MetadataBuilder() + .putString("__funcInputAlias", "true") + .build() + } + assert(!function.isTableFunc) + val funcName = function.name.funcName + + // Use captured SQL configs when parsing a SQL function. + val conf = new SQLConf() + function.getSQLConfigs.foreach { case (k, v) => conf.settings.put(k, v) } + SQLConf.withExistingConf(conf) { + val inputParam = function.inputParam + val returnType = function.getScalarFuncReturnType + val (expression, query) = function.getExpressionAndQuery(parser, isTableFunc = false) + assert(expression.isDefined || query.isDefined) + + // Check function arguments + val paramSize = inputParam.map(_.size).getOrElse(0) + if (input.size > paramSize) { + throw QueryCompilationErrors.wrongNumArgsError( + name, paramSize.toString, input.size) + } + + val inputs = inputParam.map { param => + // Attributes referencing the input parameters inside the function can use the + // function name as a qualifier. E.G.: + // `create function foo(a int) returns int return foo.a` + val qualifier = Seq(funcName) + val paddedInput = input ++ + param.takeRight(paramSize - input.size).map { p => + val defaultExpr = p.getDefault() + if (defaultExpr.isDefined) { + Cast(parseDefault(defaultExpr.get, parser), p.dataType) + } else { + throw QueryCompilationErrors.wrongNumArgsError( + name, paramSize.toString, input.size) + } + } + + paddedInput.zip(param.fields).map { + case (expr, param) => + Alias(Cast(expr, param.dataType), param.name)( + qualifier = qualifier, + // mark the alias as function input + explicitMetadata = Some(metaForFuncInputAlias)) + } + }.getOrElse(Nil) + + val body = if (query.isDefined) ScalarSubquery(query.get) else expression.get + Project(Alias(Cast(body, returnType), funcName)() :: Nil, + Project(inputs, LocalRelation(inputs.flatMap(_.references)))) + } + } + /** * Constructs a [[TableFunctionBuilder]] based on the provided class that represents a function. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala index b00cae22cf9c0..a76ca7b15c278 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala @@ -45,6 +45,14 @@ trait UserDefinedFunction { */ def properties: Map[String, String] + /** + * Get SQL configs from the function properties. + * Use this to restore the SQL configs that should be used for this function. 
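The qualifier handling in makeSQLFunctionPlan above lets a parameter be referenced with the function name as a qualifier, as noted in its inline comment (`foo.a`). A sketch with an illustrative function name:

// A parameter may be referenced qualified by the function name.
spark.sql("CREATE FUNCTION double_it(a INT) RETURNS INT RETURN double_it.a * 2")
spark.sql("SELECT double_it(21)")  // returns 42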
+ */ + def getSQLConfigs: Map[String, String] = { + UserDefinedFunction.propertiesToSQLConfigs(properties) + } + /** * Owner of the function */ @@ -142,4 +150,17 @@ object UserDefinedFunction { * Verify if the function is a [[UserDefinedFunction]]. */ def isUserDefinedFunction(className: String): Boolean = SQLFunction.isSQLFunction(className) + + /** + * Covert properties to SQL configs. + */ + def propertiesToSQLConfigs(properties: Map[String, String]): Map[String, String] = { + try { + for ((key, value) <- properties if key.startsWith(SQL_CONFIG_PREFIX)) + yield (key.substring(SQL_CONFIG_PREFIX.length), value) + } catch { + case e: Exception => throw SparkException.internalError( + "Corrupted user defined function SQL configs in catalog", cause = e) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSQLFunctionNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSQLFunctionNode.scala new file mode 100644 index 0000000000000..d9da38b4c2af4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSQLFunctionNode.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.analysis.{SQLFunctionExpression, SQLFunctionNode, SQLScalarFunction, SQLTableFunction} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * This rule removes [[SQLScalarFunction]] and [[SQLFunctionNode]] wrapper. They are respected + * till the end of analysis stage because we want to see which part of an analyzed logical + * plan is generated from a SQL function and also perform ACL checks. + */ +object EliminateSQLFunctionNode extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + // Include subqueries when eliminating SQL function expressions otherwise we might miss + // expressions in subqueries which can be inlined by the rule `OptimizeOneRowRelationSubquery`. 
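As a sketch of the end state EliminateSQLFunctionNode produces: after optimization the SQLScalarFunction wrapper is gone and only the inlined function body remains. The `area` function is the one created in SQLFunctionSuite below; the check itself is illustrative:

import org.apache.spark.sql.catalyst.analysis.SQLScalarFunction
val optimized = spark.sql("SELECT area(1, 2)").queryExecution.optimizedPlan
// The wrapper kept through analysis has been stripped; only width * height runs.
assert(!optimized.expressions.exists(_.exists(_.isInstanceOf[SQLScalarFunction])))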
+ plan.transformWithSubqueries { + case SQLFunctionNode(_, child) => child + case f: SQLTableFunction => + throw SparkException.internalError( + s"SQL table function plan should be rewritten during analysis: $f") + case p: LogicalPlan => p.transformExpressions { + case f: SQLScalarFunction => f.child + case f: SQLFunctionExpression => + throw SparkException.internalError( + s"SQL function expression should be rewritten during analysis: $f") + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8ee2226947ec9..9d269f37e58b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -315,6 +315,7 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateSubqueryAliases, EliminatePipeOperators, EliminateView, + EliminateSQLFunctionNode, ReplaceExpressions, RewriteNonCorrelatedExists, PullOutGroupingExpressions, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index b56085ecae8d6..9856a26346f6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -93,6 +93,7 @@ object TreePattern extends Enumeration { val SESSION_WINDOW: Value = Value val SORT: Value = Value val SQL_FUNCTION_EXPRESSION: Value = Value + val SQL_SCALAR_FUNCTION: Value = Value val SQL_TABLE_FUNCTION: Value = Value val SUBQUERY_ALIAS: Value = Value val SUM: Value = Value diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-udf.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-udf.sql.out new file mode 100644 index 0000000000000..b3c10e929f297 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-udf.sql.out @@ -0,0 +1,575 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE FUNCTION foo1a0() RETURNS INT RETURN 1 +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo1a0`" + } +} + + +-- !query +SELECT foo1a0() +-- !query analysis +Project [spark_catalog.default.foo1a0() AS spark_catalog.default.foo1a0()#x] ++- Project + +- OneRowRelation + + +-- !query +SELECT foo1a0(1) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "1", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "0", + "functionName" : "`spark_catalog`.`default`.`foo1a0`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 16, + "fragment" : "foo1a0(1)" + } ] +} + + +-- !query +CREATE FUNCTION foo1a1(a INT) RETURNS INT RETURN 1 +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : 
"`default`.`foo1a1`" + } +} + + +-- !query +SELECT foo1a1(1) +-- !query analysis +Project [spark_catalog.default.foo1a1(a#x) AS spark_catalog.default.foo1a1(1)#x] ++- Project [cast(1 as int) AS a#x] + +- OneRowRelation + + +-- !query +SELECT foo1a1(1, 2) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "2", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "1", + "functionName" : "`spark_catalog`.`default`.`foo1a1`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + "fragment" : "foo1a1(1, 2)" + } ] +} + + +-- !query +CREATE FUNCTION foo1a2(a INT, b INT, c INT, d INT) RETURNS INT RETURN 1 +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo1a2`" + } +} + + +-- !query +SELECT foo1a2(1, 2, 3, 4) +-- !query analysis +Project [spark_catalog.default.foo1a2(a#x, b#x, c#x, d#x) AS spark_catalog.default.foo1a2(1, 2, 3, 4)#x] ++- Project [cast(1 as int) AS a#x, cast(2 as int) AS b#x, cast(3 as int) AS c#x, cast(4 as int) AS d#x] + +- OneRowRelation + + +-- !query +CREATE FUNCTION foo2_1a(a INT) RETURNS INT RETURN a +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo2_1a`" + } +} + + +-- !query +SELECT foo2_1a(5) +-- !query analysis +Project [spark_catalog.default.foo2_1a(a#x) AS spark_catalog.default.foo2_1a(5)#x] ++- Project [cast(5 as int) AS a#x] + +- OneRowRelation + + +-- !query +CREATE FUNCTION foo2_1b(a INT, b INT) RETURNS INT RETURN a + b +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo2_1b`" + } +} + + +-- !query +SELECT foo2_1b(5, 6) +-- !query analysis +Project [spark_catalog.default.foo2_1b(a#x, b#x) AS spark_catalog.default.foo2_1b(5, 6)#x] ++- Project [cast(5 as int) AS a#x, cast(6 as int) AS b#x] + +- OneRowRelation + + +-- !query +CREATE FUNCTION foo2_1c(a INT, b INT) RETURNS INT RETURN 10 * (a + b) + 100 * (a -b) +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo2_1c`" + } +} + + +-- !query +SELECT foo2_1c(5, 6) +-- !query analysis +Project [spark_catalog.default.foo2_1c(a#x, b#x) AS spark_catalog.default.foo2_1c(5, 6)#x] ++- Project [cast(5 as int) AS a#x, cast(6 as int) AS b#x] + +- OneRowRelation + + +-- !query +CREATE FUNCTION foo2_1d(a INT, b INT) RETURNS INT RETURN ABS(a) - LENGTH(CAST(b AS VARCHAR(10))) +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + 
"newRoutineType" : "routine", + "routineName" : "`default`.`foo2_1d`" + } +} + + +-- !query +SELECT foo2_1d(-5, 6) +-- !query analysis +Project [spark_catalog.default.foo2_1d(a#x, b#x) AS spark_catalog.default.foo2_1d(-5, 6)#x] ++- Project [cast(-5 as int) AS a#x, cast(6 as int) AS b#x] + +- OneRowRelation + + +-- !query +CREATE FUNCTION foo2_2a(a INT) RETURNS INT RETURN SELECT a +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo2_2a`" + } +} + + +-- !query +SELECT foo2_2a(5) +-- !query analysis +Project [spark_catalog.default.foo2_2a(a#x) AS spark_catalog.default.foo2_2a(5)#x] ++- Project [cast(5 as int) AS a#x] + +- OneRowRelation + + +-- !query +CREATE FUNCTION foo2_2b(a INT) RETURNS INT RETURN 1 + (SELECT a) +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo2_2b`" + } +} + + +-- !query +SELECT foo2_2b(5) +-- !query analysis +Project [spark_catalog.default.foo2_2b(a#x) AS spark_catalog.default.foo2_2b(5)#x] +: +- Project [outer(a#x)] +: +- OneRowRelation ++- Project [cast(5 as int) AS a#x] + +- OneRowRelation + + +-- !query +CREATE FUNCTION foo2_2c(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT a)) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`a`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 21, + "stopIndex" : 21, + "fragment" : "a" + } ] +} + + +-- !query +CREATE FUNCTION foo2_2d(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT (SELECT (SELECT a)))) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`a`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 37, + "stopIndex" : 37, + "fragment" : "a" + } ] +} + + +-- !query +CREATE FUNCTION foo2_2e(a INT) RETURNS INT RETURN +SELECT a FROM (VALUES 1) AS V(c1) WHERE c1 = 2 +UNION ALL +SELECT a + 1 FROM (VALUES 1) AS V(c1) +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo2_2e`" + } +} + + +-- !query +CREATE FUNCTION foo2_2f(a INT) RETURNS INT RETURN +SELECT a FROM (VALUES 1) AS V(c1) +EXCEPT +SELECT a + 1 FROM (VALUES 1) AS V(a) +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo2_2f`" + } +} + + +-- !query +CREATE FUNCTION foo2_2g(a INT) RETURNS INT RETURN +SELECT a FROM (VALUES 1) AS V(c1) +INTERSECT +SELECT a FROM (VALUES 1) AS V(a) +-- !query analysis 
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo2_2g`" + } +} + + +-- !query +DROP TABLE IF EXISTS t1 +-- !query analysis +DropTable true, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t1 + + +-- !query +DROP TABLE IF EXISTS t2 +-- !query analysis +DropTable true, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t2 + + +-- !query +DROP TABLE IF EXISTS ts +-- !query analysis +DropTable true, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.ts + + +-- !query +DROP TABLE IF EXISTS tm +-- !query analysis +DropTable true, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.tm + + +-- !query +DROP TABLE IF EXISTS ta +-- !query analysis +DropTable true, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.ta + + +-- !query +DROP TABLE IF EXISTS V1 +-- !query analysis +DropTable true, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.V1 + + +-- !query +DROP TABLE IF EXISTS V2 +-- !query analysis +DropTable true, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.V2 + + +-- !query +DROP VIEW IF EXISTS t1 +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`t1`, true, true, false + + +-- !query +DROP VIEW IF EXISTS t2 +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`t2`, true, true, false + + +-- !query +DROP VIEW IF EXISTS ts +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`ts`, true, true, false + + +-- !query +DROP VIEW IF EXISTS tm +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`tm`, true, true, false + + +-- !query +DROP VIEW IF EXISTS ta +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`ta`, true, true, false + + +-- !query +DROP VIEW IF EXISTS V1 +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`V1`, true, true, false + + +-- !query +DROP VIEW IF EXISTS V2 +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`V2`, true, true, false + + +-- !query +CREATE FUNCTION foo2_3(a INT, b INT) RETURNS INT RETURN a + b +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo2_3`" + } +} + + +-- !query +CREATE VIEW V1(c1, c2) AS VALUES (1, 2), (3, 4), (5, 6) +-- !query analysis +CreateViewCommand `spark_catalog`.`default`.`V1`, [(c1,None), (c2,None)], VALUES (1, 2), (3, 4), (5, 6), false, false, PersistedView, COMPENSATION, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +CREATE VIEW V2(c1, c2) AS VALUES (-1, -2), (-3, -4), (-5, -6) +-- !query analysis +CreateViewCommand `spark_catalog`.`default`.`V2`, [(c1,None), (c2,None)], VALUES (-1, -2), (-3, -4), (-5, -6), false, false, PersistedView, COMPENSATION, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT foo2_3(c1, c2), foo2_3(c2, 1), foo2_3(c1, c2) - foo2_3(c2, c1 - 1) FROM V1 ORDER BY 1, 2, 3 +-- !query analysis +Sort [spark_catalog.default.foo2_3(c1, c2)#x ASC NULLS FIRST, spark_catalog.default.foo2_3(c2, 1)#x ASC NULLS FIRST, (spark_catalog.default.foo2_3(c1, c2) - spark_catalog.default.foo2_3(c2, (c1 - 1)))#x ASC NULLS FIRST], true ++- Project 
[spark_catalog.default.foo2_3(a#x, b#x) AS spark_catalog.default.foo2_3(c1, c2)#x, spark_catalog.default.foo2_3(a#x, b#x) AS spark_catalog.default.foo2_3(c2, 1)#x, (spark_catalog.default.foo2_3(a#x, b#x) - spark_catalog.default.foo2_3(a#x, b#x)) AS (spark_catalog.default.foo2_3(c1, c2) - spark_catalog.default.foo2_3(c2, (c1 - 1)))#x] + +- Project [c1#x, c2#x, cast(c1#x as int) AS a#x, cast(c2#x as int) AS b#x, cast(c2#x as int) AS a#x, cast(1 as int) AS b#x, cast(c1#x as int) AS a#x, cast(c2#x as int) AS b#x, cast(c2#x as int) AS a#x, cast((c1#x - 1) as int) AS b#x] + +- SubqueryAlias spark_catalog.default.v1 + +- View (`spark_catalog`.`default`.`v1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT * FROM V1 WHERE foo2_3(c1, 0) = c1 AND foo2_3(c1, c2) < 8 +-- !query analysis +Project [c1#x, c2#x] ++- Project [c1#x, c2#x] + +- Filter ((spark_catalog.default.foo2_3(a#x, b#x) = c1#x) AND (spark_catalog.default.foo2_3(a#x, b#x) < 8)) + +- Project [c1#x, c2#x, cast(c1#x as int) AS a#x, cast(0 as int) AS b#x, cast(c1#x as int) AS a#x, cast(c2#x as int) AS b#x] + +- SubqueryAlias spark_catalog.default.v1 + +- View (`spark_catalog`.`default`.`v1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT foo2_3(SUM(c1), SUM(c2)), SUM(c1) + SUM(c2), SUM(foo2_3(c1, c2) + foo2_3(c2, c1) - foo2_3(c2, c1)) +FROM V1 +-- !query analysis +Project [spark_catalog.default.foo2_3(a#x, b#x) AS spark_catalog.default.foo2_3(sum(c1), sum(c2))#x, (sum(c1) + sum(c2))#xL, sum(((spark_catalog.default.foo2_3(c1, c2) + spark_catalog.default.foo2_3(c2, c1)) - spark_catalog.default.foo2_3(c2, c1)))#xL] ++- Project [sum(c1)#xL, sum(c2)#xL, (sum(c1) + sum(c2))#xL, sum(((spark_catalog.default.foo2_3(c1, c2) + spark_catalog.default.foo2_3(c2, c1)) - spark_catalog.default.foo2_3(c2, c1)))#xL, cast(sum(c1)#xL as int) AS a#x, cast(sum(c2)#xL as int) AS b#x] + +- Aggregate [sum(c1#x) AS sum(c1)#xL, sum(c2#x) AS sum(c2)#xL, (sum(c1#x) + sum(c2#x)) AS (sum(c1) + sum(c2))#xL, sum(((spark_catalog.default.foo2_3(a#x, b#x) + spark_catalog.default.foo2_3(a#x, b#x)) - spark_catalog.default.foo2_3(a#x, b#x))) AS sum(((spark_catalog.default.foo2_3(c1, c2) + spark_catalog.default.foo2_3(c2, c1)) - spark_catalog.default.foo2_3(c2, c1)))#xL] + +- Project [c1#x, c2#x, cast(c1#x as int) AS a#x, cast(c2#x as int) AS b#x, cast(c2#x as int) AS a#x, cast(c1#x as int) AS b#x, cast(c2#x as int) AS a#x, cast(c1#x as int) AS b#x] + +- SubqueryAlias spark_catalog.default.v1 + +- View (`spark_catalog`.`default`.`v1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +CREATE FUNCTION foo2_4a(a ARRAY) RETURNS STRING RETURN +SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] FROM (SELECT MAP('a', 1, 'b', 2) rank) +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo2_4a`" + } +} + + +-- !query +SELECT foo2_4a(ARRAY('a', 'b')) +-- !query analysis +Project [spark_catalog.default.foo2_4a(a#x) AS spark_catalog.default.foo2_4a(array(a, b))#x] +: +- Project [array_sort(outer(a#x), lambdafunction((rank#x[lambda i#x] - rank#x[lambda j#x]), lambda i#x, lambda j#x, 
false), false)[0] AS array_sort(outer(foo2_4a.a), lambdafunction((rank[namedlambdavariable()] - rank[namedlambdavariable()]), namedlambdavariable(), namedlambdavariable()))[0]#x] +: +- SubqueryAlias __auto_generated_subquery_name +: +- Project [map(a, 1, b, 2) AS rank#x] +: +- OneRowRelation ++- Project [cast(array(a, b) as array) AS a#x] + +- OneRowRelation + + +-- !query +CREATE FUNCTION foo2_4b(m MAP, k STRING) RETURNS STRING RETURN +SELECT v || ' ' || v FROM (SELECT upper(m[k]) AS v) +-- !query analysis +org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException +{ + "errorClass" : "ROUTINE_ALREADY_EXISTS", + "sqlState" : "42723", + "messageParameters" : { + "existingRoutineType" : "routine", + "newRoutineType" : "routine", + "routineName" : "`default`.`foo2_4b`" + } +} + + +-- !query +SELECT foo2_4b(map('a', 'hello', 'b', 'world'), 'a') +-- !query analysis +Project [spark_catalog.default.foo2_4b(m#x, k#x) AS spark_catalog.default.foo2_4b(map(a, hello, b, world), a)#x] +: +- Project [concat(concat(v#x, ), v#x) AS concat(concat(v, ), v)#x] +: +- SubqueryAlias __auto_generated_subquery_name +: +- Project [upper(outer(m#x)[outer(k#x)]) AS v#x] +: +- OneRowRelation ++- Project [cast(map(a, hello, b, world) as map) AS m#x, cast(a as string) AS k#x] + +- OneRowRelation + + +-- !query +DROP VIEW V2 +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`V2`, false, true, false + + +-- !query +DROP VIEW V1 +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`V1`, false, true, false diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-udf.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-udf.sql new file mode 100644 index 0000000000000..34cb41d726766 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-udf.sql @@ -0,0 +1,122 @@ +-- test cases for SQL User Defined Functions + +-- 1. CREATE FUNCTION +-- 1.1 Parameter +-- 1.1.a A scalar function with various numbers of parameter +-- Expect success +CREATE FUNCTION foo1a0() RETURNS INT RETURN 1; +-- Expect: 1 +SELECT foo1a0(); +-- Expect failure +SELECT foo1a0(1); + +CREATE FUNCTION foo1a1(a INT) RETURNS INT RETURN 1; +-- Expect: 1 +SELECT foo1a1(1); +-- Expect failure +SELECT foo1a1(1, 2); + +CREATE FUNCTION foo1a2(a INT, b INT, c INT, d INT) RETURNS INT RETURN 1; +-- Expect: 1 +SELECT foo1a2(1, 2, 3, 4); + +------------------------------- +-- 2. Scalar SQL UDF +-- 2.1 deterministic simple expressions +CREATE FUNCTION foo2_1a(a INT) RETURNS INT RETURN a; +SELECT foo2_1a(5); + +CREATE FUNCTION foo2_1b(a INT, b INT) RETURNS INT RETURN a + b; +SELECT foo2_1b(5, 6); + +CREATE FUNCTION foo2_1c(a INT, b INT) RETURNS INT RETURN 10 * (a + b) + 100 * (a -b); +SELECT foo2_1c(5, 6); + +CREATE FUNCTION foo2_1d(a INT, b INT) RETURNS INT RETURN ABS(a) - LENGTH(CAST(b AS VARCHAR(10))); +SELECT foo2_1d(-5, 6); + +-- 2.2 deterministic complex expression with subqueries +-- 2.2.1 Nested Scalar subqueries +CREATE FUNCTION foo2_2a(a INT) RETURNS INT RETURN SELECT a; +SELECT foo2_2a(5); + +CREATE FUNCTION foo2_2b(a INT) RETURNS INT RETURN 1 + (SELECT a); +SELECT foo2_2b(5); + +-- Expect error: deep correlation is not yet supported +CREATE FUNCTION foo2_2c(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT a)); +-- SELECT foo2_2c(5); + +-- Expect error: deep correlation is not yet supported +CREATE FUNCTION foo2_2d(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT (SELECT (SELECT a)))); +-- SELECT foo2_2d(5); + +-- 2.2.2 Set operations +-- Expect error: correlated scalar subquery must be aggregated. 
+CREATE FUNCTION foo2_2e(a INT) RETURNS INT RETURN +SELECT a FROM (VALUES 1) AS V(c1) WHERE c1 = 2 +UNION ALL +SELECT a + 1 FROM (VALUES 1) AS V(c1); +-- SELECT foo2_2e(5); + +-- Expect error: correlated scalar subquery must be aggregated. +CREATE FUNCTION foo2_2f(a INT) RETURNS INT RETURN +SELECT a FROM (VALUES 1) AS V(c1) +EXCEPT +SELECT a + 1 FROM (VALUES 1) AS V(a); +-- SELECT foo2_2f(5); + +-- Expect error: correlated scalar subquery must be aggregated. +CREATE FUNCTION foo2_2g(a INT) RETURNS INT RETURN +SELECT a FROM (VALUES 1) AS V(c1) +INTERSECT +SELECT a FROM (VALUES 1) AS V(a); +-- SELECT foo2_2g(5); + +-- Prepare by dropping views or tables if they already exist. +DROP TABLE IF EXISTS t1; +DROP TABLE IF EXISTS t2; +DROP TABLE IF EXISTS ts; +DROP TABLE IF EXISTS tm; +DROP TABLE IF EXISTS ta; +DROP TABLE IF EXISTS V1; +DROP TABLE IF EXISTS V2; +DROP VIEW IF EXISTS t1; +DROP VIEW IF EXISTS t2; +DROP VIEW IF EXISTS ts; +DROP VIEW IF EXISTS tm; +DROP VIEW IF EXISTS ta; +DROP VIEW IF EXISTS V1; +DROP VIEW IF EXISTS V2; + +-- 2.3 Calling Scalar UDF from various places +CREATE FUNCTION foo2_3(a INT, b INT) RETURNS INT RETURN a + b; +CREATE VIEW V1(c1, c2) AS VALUES (1, 2), (3, 4), (5, 6); +CREATE VIEW V2(c1, c2) AS VALUES (-1, -2), (-3, -4), (-5, -6); + +-- 2.3.1 Multiple times in the select list +SELECT foo2_3(c1, c2), foo2_3(c2, 1), foo2_3(c1, c2) - foo2_3(c2, c1 - 1) FROM V1 ORDER BY 1, 2, 3; + +-- 2.3.2 In the WHERE clause +SELECT * FROM V1 WHERE foo2_3(c1, 0) = c1 AND foo2_3(c1, c2) < 8; + +-- 2.3.3 Different places around an aggregate +SELECT foo2_3(SUM(c1), SUM(c2)), SUM(c1) + SUM(c2), SUM(foo2_3(c1, c2) + foo2_3(c2, c1) - foo2_3(c2, c1)) +FROM V1; + +-- 2.4 Scalar UDF with complex one row relation subquery +-- 2.4.1 higher order functions +CREATE FUNCTION foo2_4a(a ARRAY) RETURNS STRING RETURN +SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] FROM (SELECT MAP('a', 1, 'b', 2) rank); + +SELECT foo2_4a(ARRAY('a', 'b')); + +-- 2.4.2 built-in functions +CREATE FUNCTION foo2_4b(m MAP, k STRING) RETURNS STRING RETURN +SELECT v || ' ' || v FROM (SELECT upper(m[k]) AS v); + +SELECT foo2_4b(map('a', 'hello', 'b', 'world'), 'a'); + +-- Clean up +DROP VIEW V2; +DROP VIEW V1; diff --git a/sql/core/src/test/resources/sql-tests/results/sql-udf.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-udf.sql.out new file mode 100644 index 0000000000000..9f7af7c644871 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/sql-udf.sql.out @@ -0,0 +1,484 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE FUNCTION foo1a0() RETURNS INT RETURN 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo1a0() +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT foo1a0(1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "1", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "0", + "functionName" : "`spark_catalog`.`default`.`foo1a0`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 16, + "fragment" : "foo1a0(1)" + } ] +} + + +-- !query +CREATE FUNCTION foo1a1(a INT) RETURNS INT RETURN 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo1a1(1) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT foo1a1(1, 2) +-- !query schema +struct<> +-- !query output 
+org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "2", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "1", + "functionName" : "`spark_catalog`.`default`.`foo1a1`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + "fragment" : "foo1a1(1, 2)" + } ] +} + + +-- !query +CREATE FUNCTION foo1a2(a INT, b INT, c INT, d INT) RETURNS INT RETURN 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo1a2(1, 2, 3, 4) +-- !query schema +struct +-- !query output +1 + + +-- !query +CREATE FUNCTION foo2_1a(a INT) RETURNS INT RETURN a +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo2_1a(5) +-- !query schema +struct +-- !query output +5 + + +-- !query +CREATE FUNCTION foo2_1b(a INT, b INT) RETURNS INT RETURN a + b +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo2_1b(5, 6) +-- !query schema +struct +-- !query output +11 + + +-- !query +CREATE FUNCTION foo2_1c(a INT, b INT) RETURNS INT RETURN 10 * (a + b) + 100 * (a -b) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo2_1c(5, 6) +-- !query schema +struct +-- !query output +10 + + +-- !query +CREATE FUNCTION foo2_1d(a INT, b INT) RETURNS INT RETURN ABS(a) - LENGTH(CAST(b AS VARCHAR(10))) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo2_1d(-5, 6) +-- !query schema +struct +-- !query output +4 + + +-- !query +CREATE FUNCTION foo2_2a(a INT) RETURNS INT RETURN SELECT a +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo2_2a(5) +-- !query schema +struct +-- !query output +5 + + +-- !query +CREATE FUNCTION foo2_2b(a INT) RETURNS INT RETURN 1 + (SELECT a) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo2_2b(5) +-- !query schema +struct +-- !query output +6 + + +-- !query +CREATE FUNCTION foo2_2c(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT a)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`a`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 21, + "stopIndex" : 21, + "fragment" : "a" + } ] +} + + +-- !query +CREATE FUNCTION foo2_2d(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT (SELECT (SELECT a)))) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`a`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 37, + "stopIndex" : 37, + "fragment" : "a" + } ] +} + + +-- !query +CREATE FUNCTION foo2_2e(a INT) RETURNS INT RETURN +SELECT a FROM (VALUES 1) AS V(c1) WHERE c1 = 2 +UNION ALL +SELECT a + 1 FROM (VALUES 1) AS V(c1) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE FUNCTION foo2_2f(a INT) RETURNS INT RETURN +SELECT a FROM (VALUES 1) AS V(c1) +EXCEPT +SELECT a + 1 FROM (VALUES 1) AS V(a) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE FUNCTION foo2_2g(a INT) RETURNS INT RETURN +SELECT a FROM (VALUES 1) AS V(c1) +INTERSECT +SELECT a FROM (VALUES 1) AS V(a) +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE 
IF EXISTS t1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS t2 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS ts +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS tm +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS ta +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS V1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS V2 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW IF EXISTS t1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW IF EXISTS t2 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW IF EXISTS ts +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW IF EXISTS tm +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW IF EXISTS ta +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW IF EXISTS V1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW IF EXISTS V2 +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE FUNCTION foo2_3(a INT, b INT) RETURNS INT RETURN a + b +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE VIEW V1(c1, c2) AS VALUES (1, 2), (3, 4), (5, 6) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE VIEW V2(c1, c2) AS VALUES (-1, -2), (-3, -4), (-5, -6) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo2_3(c1, c2), foo2_3(c2, 1), foo2_3(c1, c2) - foo2_3(c2, c1 - 1) FROM V1 ORDER BY 1, 2, 3 +-- !query schema +struct +-- !query output +3 3 1 +7 5 1 +11 7 1 + + +-- !query +SELECT * FROM V1 WHERE foo2_3(c1, 0) = c1 AND foo2_3(c1, c2) < 8 +-- !query schema +struct +-- !query output +1 2 +3 4 + + +-- !query +SELECT foo2_3(SUM(c1), SUM(c2)), SUM(c1) + SUM(c2), SUM(foo2_3(c1, c2) + foo2_3(c2, c1) - foo2_3(c2, c1)) +FROM V1 +-- !query schema +struct +-- !query output +21 21 21 + + +-- !query +CREATE FUNCTION foo2_4a(a ARRAY) RETURNS STRING RETURN +SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] FROM (SELECT MAP('a', 1, 'b', 2) rank) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo2_4a(ARRAY('a', 'b')) +-- !query schema +struct +-- !query output +a + + +-- !query +CREATE FUNCTION foo2_4b(m MAP, k STRING) RETURNS STRING RETURN +SELECT v || ' ' || v FROM (SELECT upper(m[k]) AS v) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT foo2_4b(map('a', 'hello', 'b', 'world'), 'a') +-- !query schema +struct +-- !query output +HELLO HELLO + + +-- !query +DROP VIEW V2 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW V1 +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLFunctionSuite.scala new file mode 100644 index 0000000000000..4da3b9ab1d06b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLFunctionSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Test suite for SQL user-defined functions (UDFs). + */ +class SQLFunctionSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + protected override def beforeAll(): Unit = { + super.beforeAll() + Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t") + } + + test("SQL scalar function") { + withUserDefinedFunction("area" -> false) { + sql( + """ + |CREATE FUNCTION area(width DOUBLE, height DOUBLE) + |RETURNS DOUBLE + |RETURN width * height + |""".stripMargin) + checkAnswer(sql("SELECT area(1, 2)"), Row(2)) + checkAnswer(sql("SELECT area(a, b) FROM t"), Seq(Row(0), Row(2))) + } + } + + test("SQL scalar function with subquery in the function body") { + withUserDefinedFunction("foo" -> false) { + withTable("tbl") { + sql("CREATE TABLE tbl AS SELECT * FROM VALUES (1, 2), (1, 3), (2, 3) t(a, b)") + sql( + """ + |CREATE FUNCTION foo(x INT) RETURNS INT + |RETURN SELECT SUM(b) FROM tbl WHERE x = a; + |""".stripMargin) + checkAnswer(sql("SELECT foo(1)"), Row(5)) + checkAnswer(sql("SELECT foo(a) FROM t"), Seq(Row(null), Row(5))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index c00f00ceaa355..a7af22a0554e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -79,7 +79,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { assert(info.getSource === "built-in") val validSources = Seq( - "built-in", "hive", "python_udf", "scala_udf", "java_udf", "python_udtf", "internal") + "built-in", "hive", "python_udf", "scala_udf", "java_udf", "python_udtf", "internal", + "sql_udf") validSources.foreach { source => val info = new ExpressionInfo( "testClass", null, "testName", null, "", "", "", "", "", "", source)
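A possible follow-up test sketch for SQLFunctionSuite, mirroring the WHERE-clause case from sql-udf.sql (names, data, and expected rows are copied from that file; the test itself is illustrative):

  test("SQL scalar function in a WHERE clause") {
    withUserDefinedFunction("foo2_3" -> false) {
      sql("CREATE FUNCTION foo2_3(a INT, b INT) RETURNS INT RETURN a + b")
      checkAnswer(
        sql("SELECT * FROM VALUES (1, 2), (3, 4), (5, 6) AS V1(c1, c2) " +
          "WHERE foo2_3(c1, 0) = c1 AND foo2_3(c1, c2) < 8"),
        Seq(Row(1, 2), Row(3, 4)))
    }
  }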