[SPARK-50762][SQL] Add Analyzer rule for resolving SQL scalar UDFs
### What changes were proposed in this pull request?

This PR adds a new Analyzer rule `ResolveSQLFunctions` to resolve scalar SQL UDFs by replacing each `SQLFunctionExpression` with the function's actual body. It currently supports the following operators: Project, Filter, Join, and Aggregate.

For example:
```
CREATE FUNCTION area(width DOUBLE, height DOUBLE) RETURNS DOUBLE
RETURN width * height;
```
and this query
```
SELECT area(a, b) FROM t;
```
will be resolved as
```
Project [area(width, height) AS area]
  +- Project [a, b, CAST(a AS DOUBLE) AS width, CAST(b AS DOUBLE) AS height]
    +- Relation [a, b]
```
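
For reference, a rough end-to-end sketch of exercising this from Scala (a minimal sketch, assuming a Spark build that includes SQL UDF support; the table `t` is illustrative):
```
import org.apache.spark.sql.SparkSession

// Minimal sketch; assumes SQL UDF support is available in this build.
val spark = SparkSession.builder().master("local[*]").appName("sql-udf-demo").getOrCreate()
spark.sql("CREATE TABLE t(a INT, b INT) USING parquet")
spark.sql("CREATE FUNCTION area(width DOUBLE, height DOUBLE) RETURNS DOUBLE RETURN width * height")
// EXPLAIN EXTENDED prints the analyzed plan, which should contain the nested
// Project nodes sketched above with the function body inlined.
spark.sql("EXPLAIN EXTENDED SELECT area(a, b) FROM t").show(truncate = false)
```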

### Why are the changes needed?

To support SQL UDFs.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

New SQL query tests. More tests will be added once table function resolution is supported.
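
For illustration, the query shapes the rule covers look roughly like this (hypothetical queries, not verbatim from the golden files; `t` and `foo` are illustrative, reusing the `spark` session from the sketch above):
```
// Hypothetical query shapes exercised by the new rule.
Seq(
  "SELECT foo(c1) FROM t",                           // Project
  "SELECT * FROM t WHERE foo(c1) > 0",               // Filter
  "SELECT * FROM t a JOIN t b ON foo(a.c1) = b.c1",  // Join condition
  "SELECT c1, sum(foo(c2)) FROM t GROUP BY c1"       // Aggregate
).foreach(q => spark.sql(q).collect())
```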

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#49414 from allisonwang-db/spark-50762-resolve-scalar-udf.

Authored-by: Allison Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
allisonwang-db authored and cloud-fan committed Jan 14, 2025
1 parent 3c7f5e2 commit bba6839
Showing 15 changed files with 1,753 additions and 7 deletions.
13 changes: 13 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -3126,6 +3126,13 @@
],
"sqlState" : "42K08"
},
"INVALID_SQL_FUNCTION_PLAN_STRUCTURE" : {
"message" : [
"Invalid SQL function plan structure",
"<plan>"
],
"sqlState" : "XXKD0"
},
"INVALID_SQL_SYNTAX" : {
"message" : [
"Invalid SQL syntax:"
@@ -5757,6 +5764,12 @@
],
"sqlState" : "0A000"
},
"UNSUPPORTED_SQL_UDF_USAGE" : {
"message" : [
"Using SQL function <functionName> in <nodeName> is not supported."
],
"sqlState" : "0A000"
},
"UNSUPPORTED_STREAMING_OPERATOR_WITHOUT_WATERMARK" : {
"message" : [
"<outputMode> output mode not supported for <statefulOperator> on streaming DataFrames/DataSets without watermark."
ExpressionInfo.java
@@ -51,7 +51,7 @@ public class ExpressionInfo {
"window_funcs", "xml_funcs", "table_funcs", "url_funcs", "variant_funcs"));

private static final Set<String> validSources =
new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf",
new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf", "sql_udf",
"java_udf", "python_udtf", "internal"));

public String getClassName() {
Analyzer.scala
@@ -374,6 +374,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
BindProcedures ::
ResolveTableSpec ::
ValidateAndStripPipeExpressions ::
ResolveSQLFunctions ::
ResolveAliases ::
ResolveSubquery ::
ResolveSubqueryColumnAliases ::
@@ -2364,6 +2365,277 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

/**
* This rule resolves SQL function expressions. It pulls out function inputs and places them
* in a separate [[Project]] node below the operator, and replaces the SQL function with its
* actual function body. SQL function expressions in [[Aggregate]] are handled in a special
* way. Non-aggregated SQL functions in the aggregate expressions of an Aggregate need to be
* pulled out into a Project above the Aggregate before replacing the SQL function expressions
* with actual function bodies. For example:
*
* Before:
* Aggregate [c1] [foo(c1), foo(max(c2)), sum(foo(c2)) AS sum]
* +- Relation [c1, c2]
*
* After:
* Project [foo(c1), foo(max_c2), sum]
* +- Aggregate [c1] [c1, max(c2) AS max_c2, sum(foo(c2)) AS sum]
* +- Relation [c1, c2]
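*
* An example query that produces the "Before" plan above (t and foo are illustrative):
*   SELECT foo(c1), foo(max(c2)), sum(foo(c2)) AS sum FROM t GROUP BY c1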
*/
object ResolveSQLFunctions extends Rule[LogicalPlan] {

private def hasSQLFunctionExpression(exprs: Seq[Expression]): Boolean = {
exprs.exists(_.find(_.isInstanceOf[SQLFunctionExpression]).nonEmpty)
}

/**
* Check if the function input contains aggregate expressions.
*/
private def checkFunctionInput(f: SQLFunctionExpression): Unit = {
if (f.inputs.exists(AggregateExpression.containsAggregate)) {
// The input of a SQL function should not contain aggregate functions after
// `extractAndRewrite`. If there are aggregate functions, it means they are
// nested in another aggregate function, which is not allowed.
// For example: SELECT sum(foo(sum(c1))) FROM t
// We have to throw the error here because otherwise the query plan after
// resolving the SQL function will not be valid.
throw new AnalysisException(
errorClass = "NESTED_AGGREGATE_FUNCTION",
messageParameters = Map.empty)
}
}

/**
* Resolve a SQL function expression as a logical plan and check if it can be analyzed.
*/
private def resolve(f: SQLFunctionExpression): LogicalPlan = {
// Validate the SQL function input.
checkFunctionInput(f)
val plan = v1SessionCatalog.makeSQLFunctionPlan(f.name, f.function, f.inputs)
val resolved = SQLFunctionContext.withSQLFunction {
// Resolve the SQL function plan using its context.
val conf = new SQLConf()
f.function.getSQLConfigs.foreach { case (k, v) => conf.settings.put(k, v) }
SQLConf.withExistingConf(conf) {
executeSameContext(plan)
}
}
// Fail the analysis eagerly if a SQL function cannot be resolved using its input.
SimpleAnalyzer.checkAnalysis(resolved)
resolved
}

/**
* Rewrite SQL function expressions into actual resolved function bodies and extract
* function inputs into the given project list.
*/
private def rewriteSQLFunctions[E <: Expression](
expression: E,
projectList: ArrayBuffer[NamedExpression]): E = {
val newExpr = expression match {
case f: SQLFunctionExpression if !hasSQLFunctionExpression(f.inputs) &&
// Make sure LateralColumnAliasReference in parameters is resolved and eliminated first.
// Otherwise, the projectList can contain the LateralColumnAliasReference, which will be
// pushed down to a Project without the 'referenced' alias by LCA present, leaving it
// unresolved.
!f.inputs.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
withPosition(f) {
val plan = resolve(f)
// Extract the function input project list from the SQL function plan and
// inline the SQL function expression.
plan match {
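// `makeSQLFunctionPlan` is expected to return a plan of the shape:
//   Project [<function body>]
//   +- Project [<casted and aliased function inputs>]
//      +- LocalRelation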
case Project(body :: Nil, Project(aliases, _: LocalRelation)) =>
projectList ++= aliases
SQLScalarFunction(f.function, aliases.map(_.toAttribute), body)
case o =>
throw new AnalysisException(
errorClass = "INVALID_SQL_FUNCTION_PLAN_STRUCTURE",
messageParameters = Map("plan" -> o.toString))
}
}
case o => o.mapChildren(rewriteSQLFunctions(_, projectList))
}
newExpr.asInstanceOf[E]
}

/**
* Check if the given expression contains expressions that should be extracted,
* i.e. non-aggregated SQL functions with non-foldable inputs.
*/
private def shouldExtract(e: Expression): Boolean = e match {
// Return false if the expression is already an aggregate expression.
case _: AggregateExpression => false
case _: SQLFunctionExpression => true
case _: LeafExpression => false
case o => o.children.exists(shouldExtract)
}

/**
* Extract aggregate expressions from the given expression and replace
* them with attribute references.
* Example:
* Before: foo(c1) + foo(max(c2)) + max(foo(c2))
* After: foo(c1) + foo(max_c2) + max_foo_c2
* Extracted expressions: [c1, max(c2) AS max_c2, max(foo(c2)) AS max_foo_c2]
*/
private def extractAndRewrite[T <: Expression](
expression: T,
extractedExprs: ArrayBuffer[NamedExpression]): T = {
val newExpr = expression match {
case e if !shouldExtract(e) =>
val exprToAdd: NamedExpression = e match {
case o: OuterReference => Alias(o, toPrettySQL(o.e))()
case ne: NamedExpression => ne
case o => Alias(o, toPrettySQL(o))()
}
extractedExprs += exprToAdd
exprToAdd.toAttribute
case f: SQLFunctionExpression =>
val newInputs = f.inputs.map(extractAndRewrite(_, extractedExprs))
f.copy(inputs = newInputs)
case o => o.mapChildren(extractAndRewrite(_, extractedExprs))
}
newExpr.asInstanceOf[T]
}

/**
* Replace all [[SQLFunctionExpression]]s in an expression with attribute references
* from the aliasMap.
*/
private def replaceSQLFunctionWithAttr[T <: Expression](
expr: T,
aliasMap: mutable.HashMap[Expression, Alias]): T = {
expr.transform {
case f: SQLFunctionExpression if aliasMap.contains(f.canonicalized) =>
aliasMap(f.canonicalized).toAttribute
}.asInstanceOf[T]
}

private def rewrite(plan: LogicalPlan): LogicalPlan = plan match {
// Return if a sub-tree does not contain SQLFunctionExpression.
case p: LogicalPlan if !p.containsPattern(SQL_FUNCTION_EXPRESSION) => p

case f @ Filter(cond, a: Aggregate)
if !f.resolved || AggregateExpression.containsAggregate(cond) ||
ResolveGroupingAnalytics.hasGroupingFunction(cond) ||
cond.containsPattern(TEMP_RESOLVED_COLUMN) =>
// If the filter's condition contains aggregate expressions or grouping expressions or temp
// resolved column, we cannot rewrite both the filter and the aggregate until they are
// resolved by ResolveAggregateFunctions or ResolveGroupingAnalytics, because rewriting SQL
// functions in aggregate can add an additional project on top of the aggregate
// which breaks the pattern matching in those rules.
f.copy(child = a.copy(child = rewrite(a.child)))

case h @ UnresolvedHaving(_, a: Aggregate) =>
// Similarly UnresolvedHaving should be resolved by ResolveAggregateFunctions first
// before rewriting aggregate.
h.copy(child = a.copy(child = rewrite(a.child)))

case a: Aggregate if a.resolved && hasSQLFunctionExpression(a.expressions) =>
val child = rewrite(a.child)
// Extract SQL functions in the grouping expressions and place them in a project list
// below the current aggregate. Also update their appearances in the aggregate expressions.
val bottomProjectList = ArrayBuffer.empty[NamedExpression]
val aliasMap = mutable.HashMap.empty[Expression, Alias]
val newGrouping = a.groupingExpressions.map { expr =>
expr.transformDown {
case f: SQLFunctionExpression =>
val alias = aliasMap.getOrElseUpdate(f.canonicalized, Alias(f, f.name)())
bottomProjectList += alias
alias.toAttribute
}
}
val aggregateExpressions = a.aggregateExpressions.map(
replaceSQLFunctionWithAttr(_, aliasMap))

// Rewrite SQL functions in the aggregate expressions that are not wrapped in
// aggregate functions. They need to be extracted into a project list above the
// current aggregate.
val aggExprs = ArrayBuffer.empty[NamedExpression]
val topProjectList = aggregateExpressions.map(extractAndRewrite(_, aggExprs))

// Rewrite SQL functions in the new aggregate expressions that are wrapped inside
// aggregate functions.
val newAggExprs = aggExprs.map(rewriteSQLFunctions(_, bottomProjectList))

val bottomProject = if (bottomProjectList.nonEmpty) {
Project(child.output ++ bottomProjectList, child)
} else {
child
}
val newAgg = if (newGrouping.nonEmpty || newAggExprs.nonEmpty) {
a.copy(
groupingExpressions = newGrouping,
aggregateExpressions = newAggExprs.toSeq,
child = bottomProject)
} else {
bottomProject
}
if (topProjectList.nonEmpty) Project(topProjectList, newAgg) else newAgg

case p: Project if p.resolved && hasSQLFunctionExpression(p.expressions) =>
val newChild = rewrite(p.child)
val projectList = ArrayBuffer.empty[NamedExpression]
val newPList = p.projectList.map(rewriteSQLFunctions(_, projectList))
if (newPList != newChild.output) {
p.copy(newPList, Project(newChild.output ++ projectList, newChild))
} else {
assert(projectList.isEmpty)
p.copy(child = newChild)
}

case f: Filter if f.resolved && hasSQLFunctionExpression(f.expressions) =>
val newChild = rewrite(f.child)
val projectList = ArrayBuffer.empty[NamedExpression]
val newCond = rewriteSQLFunctions(f.condition, projectList)
if (newCond != f.condition) {
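// Add a Project on top to preserve the filter's original output, since the
// input Project pushed below the filter widens the child's output.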
Project(f.output, Filter(newCond, Project(newChild.output ++ projectList, newChild)))
} else {
assert(projectList.isEmpty)
f.copy(child = newChild)
}

case j: Join if j.resolved && hasSQLFunctionExpression(j.expressions) =>
val newLeft = rewrite(j.left)
val newRight = rewrite(j.right)
val projectList = ArrayBuffer.empty[NamedExpression]
val joinCond = j.condition.map(rewriteSQLFunctions(_, projectList))
if (joinCond != j.condition) {
// Join condition cannot have non-deterministic expressions. We can safely
// replace the aliases with the original SQL function input expressions.
val aliasMap = projectList.collect { case a: Alias => a.toAttribute -> a.child }.toMap
val newJoinCond = joinCond.map(_.transform {
case a: Attribute => aliasMap.getOrElse(a, a)
})
j.copy(left = newLeft, right = newRight, condition = newJoinCond)
} else {
assert(projectList.isEmpty)
j.copy(left = newLeft, right = newRight)
}

case o: LogicalPlan if o.resolved && hasSQLFunctionExpression(o.expressions) =>
o.transformExpressionsWithPruning(_.containsPattern(SQL_FUNCTION_EXPRESSION)) {
case f: SQLFunctionExpression =>
f.failAnalysis(
errorClass = "UNSUPPORTED_SQL_UDF_USAGE",
messageParameters = Map(
"functionName" -> toSQLId(f.function.name.nameParts),
"nodeName" -> o.nodeName.toString))
}

case p: LogicalPlan => p.mapChildren(rewrite)
}

def apply(plan: LogicalPlan): LogicalPlan = {
// Only rewrite SQL functions when they are not in nested function calls.
if (SQLFunctionContext.get.nestedSQLFunctionDepth > 0) {
plan
} else {
rewrite(plan)
}
}
}

/**
* Turns projections that contain aggregate expressions into aggregations.
*/
CheckAnalysis.scala
@@ -1106,6 +1106,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
@scala.annotation.tailrec
def cleanQueryInScalarSubquery(p: LogicalPlan): LogicalPlan = p match {
case s: SubqueryAlias => cleanQueryInScalarSubquery(s.child)
// Skip SQL function node added by the Analyzer
case s: SQLFunctionNode => cleanQueryInScalarSubquery(s.child)
case p: Project => cleanQueryInScalarSubquery(p.child)
case h: ResolvedHint => cleanQueryInScalarSubquery(h.child)
case child => child
SQLFunctionExpression.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.catalog.SQLFunction
-import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{SQL_FUNCTION_EXPRESSION, TreePattern}
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression, Unevaluable}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{SQL_FUNCTION_EXPRESSION, SQL_SCALAR_FUNCTION, TreePattern}
import org.apache.spark.sql.types.DataType

/**
@@ -39,3 +39,52 @@ case class SQLFunctionExpression(
newChildren: IndexedSeq[Expression]): SQLFunctionExpression = copy(inputs = newChildren)
final override val nodePatterns: Seq[TreePattern] = Seq(SQL_FUNCTION_EXPRESSION)
}

/**
* A wrapper node for a SQL scalar function expression.
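* `inputs` records the function arguments for display purposes, while `child` is
* the resolved function body with its parameters bound to the input attributes.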
*/
case class SQLScalarFunction(function: SQLFunction, inputs: Seq[Expression], child: Expression)
extends UnaryExpression with Unevaluable {
override def dataType: DataType = child.dataType
override def toString: String = s"${function.name}(${inputs.mkString(", ")})"
override def sql: String = s"${function.name}(${inputs.map(_.sql).mkString(", ")})"
override protected def withNewChildInternal(newChild: Expression): SQLScalarFunction = {
copy(child = newChild)
}
final override val nodePatterns: Seq[TreePattern] = Seq(SQL_SCALAR_FUNCTION)
// The `inputs` field is for display only and does not matter in execution.
override lazy val canonicalized: Expression = copy(inputs = Nil, child = child.canonicalized)
override lazy val deterministic: Boolean = {
function.deterministic.getOrElse(true) && children.forall(_.deterministic)
}
}

/**
* Provide a way to keep state during analysis for resolving nested SQL functions.
*
* @param nestedSQLFunctionDepth The nested depth in the SQL function resolution. A SQL function
* expression should only be expanded as a [[SQLScalarFunction]] if
* the nested depth is 0.
*/
case class SQLFunctionContext(nestedSQLFunctionDepth: Int = 0)

object SQLFunctionContext {

private val value = new ThreadLocal[SQLFunctionContext]() {
override def initialValue: SQLFunctionContext = SQLFunctionContext()
}

def get: SQLFunctionContext = value.get()

def reset(): Unit = value.remove()

private def set(context: SQLFunctionContext): Unit = value.set(context)

def withSQLFunction[A](f: => A): A = {
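// Increase the nesting depth while `f` runs, and restore the original
// context afterwards, even if `f` throws.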
val originContext = value.get()
val context = originContext.copy(
nestedSQLFunctionDepth = originContext.nestedSQLFunctionDepth + 1)
set(context)
try f finally { set(originContext) }
}
}