diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index e97dc3c2d4e12e9..2d2a061df2c813f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -4429,7 +4429,11 @@ public Command visitCreateUserDefineFunction(CreateUserDefineFunctionContext ctx functionArgsDefInfo = new FunctionArgsDefInfo(new ArrayList<>(), false); } DataType returnType = typedVisit(ctx.returnType); + returnType = returnType.conversion(); DataType intermediateType = ctx.intermediateType != null ? typedVisit(ctx.intermediateType) : null; + if (intermediateType != null) { + intermediateType = intermediateType.conversion(); + } Map properties = ctx.propertyClause() != null ? Maps.newHashMap(visitPropertyClause(ctx.propertyClause())) : Maps.newHashMap(); @@ -4482,9 +4486,11 @@ public Command visitDropFunction(DropFunctionContext ctx) { String functionName = ctx.functionIdentifier().functionNameIdentifier().getText(); String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null; FunctionName function = new FunctionName(dbName, functionName); - FunctionArgsDefInfo functionArgsDefInfo = null; + FunctionArgsDefInfo functionArgsDefInfo; if (ctx.functionArguments() != null) { functionArgsDefInfo = visitFunctionArguments(ctx.functionArguments()); + } else { + functionArgsDefInfo = new FunctionArgsDefInfo(new ArrayList<>(), false); } return new DropFunctionCommand(setType, ifExists, function, functionArgsDefInfo); } @@ -4497,7 +4503,7 @@ public FunctionArgsDefInfo visitFunctionArguments(FunctionArgumentsContext ctx) if (child instanceof FunctionArgumentContext) { DataType dataType = visitFunctionArgument((FunctionArgumentContext) child); if (dataType != null) { - argTypeDefs.add(dataType); + argTypeDefs.add(dataType.conversion()); } else { isVariadic = true; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java index 2c2a803f477fa59..06c6c20661d076c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java @@ -18,8 +18,11 @@ package org.apache.doris.nereids.trees.plans.commands; import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.FunctionCallExpr; import org.apache.doris.analysis.FunctionName; +import org.apache.doris.analysis.FunctionParams; import org.apache.doris.analysis.SetType; +import org.apache.doris.analysis.SlotRef; import org.apache.doris.analysis.StmtType; import org.apache.doris.catalog.AggregateFunction; import org.apache.doris.catalog.AliasFunction; @@ -43,10 +46,26 @@ import org.apache.doris.common.util.Util; import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.analyzer.Scope; +import org.apache.doris.nereids.analyzer.UnboundSlot; import org.apache.doris.nereids.glue.translator.ExpressionTranslator; import org.apache.doris.nereids.glue.translator.PlanTranslatorContext; import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.BitAnd; +import org.apache.doris.nereids.trees.expressions.BitNot; +import org.apache.doris.nereids.trees.expressions.BitOr; +import org.apache.doris.nereids.trees.expressions.BitXor; +import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.IntegralDivide; +import org.apache.doris.nereids.trees.expressions.Mod; +import org.apache.doris.nereids.trees.expressions.Multiply; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.Subtract; +import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo; import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; @@ -67,6 +86,7 @@ import io.grpc.ManagedChannel; import io.grpc.netty.NettyChannelBuilder; import org.apache.commons.codec.binary.Hex; +import org.apache.commons.collections.map.CaseInsensitiveMap; import org.apache.commons.lang3.StringUtils; import java.io.IOException; @@ -211,6 +231,7 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { String dbName = functionName.getDb(); if (dbName == null) { dbName = ctx.getDatabase(); + functionName.setDb(dbName); } Database db = Env.getCurrentInternalCatalog().getDbOrDdlException(dbName); db.addFunction(function, ifNotExists); @@ -891,18 +912,144 @@ private TFunctionBinaryType getFunctionBinaryType(String type) { private void analyzeAliasFunction(ConnectContext ctx) throws AnalysisException { function = AliasFunction.createFunction(functionName, argsDef.getArgTypes(), Type.VARCHAR, argsDef.isVariadic(), parameters, translateToLegacyExpr(originFunction, ctx)); - ((AliasFunction) function).analyze(); } /** * translate to legacy expr, which do not need complex expression and table columns */ - private Expr translateToLegacyExpr(Expression expression, ConnectContext ctx) { + private Expr translateToLegacyExpr(Expression expression, ConnectContext ctx) throws AnalysisException { LogicalEmptyRelation plan = new LogicalEmptyRelation( ConnectContext.get().getStatementContext().getNextRelationId(), new ArrayList<>()); CascadesContext cascadesContext = CascadesContext.initContext(ctx.getStatementContext(), plan, PhysicalProperties.ANY); + Map argTypeMap = new CaseInsensitiveMap(); + List argTypes = argsDef.getArgTypeDefs(); + if (!parameters.isEmpty()) { + if (parameters.size() != argTypes.size()) { + throw new AnalysisException(String.format("arguments' size must be same as parameters' size," + + "arguments : %s, parameters : %s", argTypes.size(), parameters.size())); + } + for (int i = 0; i < parameters.size(); ++i) { + argTypeMap.put(parameters.get(i), argTypes.get(i)); + } + } + ExpressionAnalyzer analyzer = new CustomExpressionAnalyzer(cascadesContext, argTypeMap); + expression = analyzer.analyze(expression); + PlanTranslatorContext translatorContext = new PlanTranslatorContext(cascadesContext); - return ExpressionTranslator.translate(expression, translatorContext); + ExpressionToExpr translator = new ExpressionToExpr(); + return expression.accept(translator, translatorContext); + } + + private static class CustomExpressionAnalyzer extends ExpressionAnalyzer { + private Map argTypeMap; + + public CustomExpressionAnalyzer(CascadesContext cascadesContext, Map argTypeMap) { + super(null, new Scope(ImmutableList.of()), cascadesContext, false, false); + this.argTypeMap = argTypeMap; + } + + @Override + public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteContext context) { + DataType dataType = argTypeMap.get(unboundSlot.getName()); + if (dataType == null) { + throw new org.apache.doris.nereids.exceptions.AnalysisException( + String.format("param %s's datatype is missed", unboundSlot.getName())); + } + return new SlotReference(unboundSlot.getName(), dataType); + } + } + + private static class ExpressionToExpr extends ExpressionTranslator { + @Override + public Expr visitSlotReference(SlotReference slotReference, PlanTranslatorContext context) { + SlotRef slotRef = new SlotRef(slotReference.getDataType().toCatalogDataType(), slotReference.nullable()); + slotRef.setLabel(slotReference.getName()); + slotRef.setCol(slotReference.getName()); + slotRef.setDisableTableName(true); + return slotRef; + } + + @Override + public Expr visitBoundFunction(BoundFunction function, PlanTranslatorContext context) { + return makeFunctionCallExpr(function, function.getName(), function.hasVarArguments(), context); + } + + @Override + public Expr visitAdd(Add add, PlanTranslatorContext context) { + return makeFunctionCallExpr(add, "add", false, context); + } + + @Override + public Expr visitSubtract(Subtract subtract, PlanTranslatorContext context) { + return makeFunctionCallExpr(subtract, "subtract", false, context); + } + + @Override + public Expr visitMultiply(Multiply multiply, PlanTranslatorContext context) { + return makeFunctionCallExpr(multiply, "multiply", false, context); + } + + @Override + public Expr visitDivide(Divide divide, PlanTranslatorContext context) { + return makeFunctionCallExpr(divide, "divide", false, context); + } + + @Override + public Expr visitIntegralDivide(IntegralDivide integralDivide, PlanTranslatorContext context) { + return makeFunctionCallExpr(integralDivide, "integralDivide", false, context); + } + + @Override + public Expr visitMod(Mod mod, PlanTranslatorContext context) { + return makeFunctionCallExpr(mod, "mod", false, context); + } + + @Override + public Expr visitBitAnd(BitAnd bitAnd, PlanTranslatorContext context) { + return makeFunctionCallExpr(bitAnd, "bitAnd", false, context); + } + + @Override + public Expr visitBitOr(BitOr bitOr, PlanTranslatorContext context) { + return makeFunctionCallExpr(bitOr, "bitOr", false, context); + } + + @Override + public Expr visitBitXor(BitXor bitXor, PlanTranslatorContext context) { + return makeFunctionCallExpr(bitXor, "bitXor", false, context); + } + + @Override + public Expr visitBitNot(BitNot bitNot, PlanTranslatorContext context) { + return makeFunctionCallExpr(bitNot, "bitNot", false, context); + } + + private Expr makeFunctionCallExpr(Expression expression, String name, boolean hasVarArguments, + PlanTranslatorContext context) { + List arguments = expression.getArguments().stream() + .map(arg -> arg.accept(this, context)) + .collect(Collectors.toList()); + + List argTypes = expression.getArguments().stream() + .map(Expression::getDataType) + .map(DataType::toCatalogDataType) + .collect(Collectors.toList()); + + NullableMode nullableMode = expression.nullable() + ? NullableMode.ALWAYS_NULLABLE + : NullableMode.ALWAYS_NOT_NULLABLE; + + org.apache.doris.catalog.ScalarFunction catalogFunction = new org.apache.doris.catalog.ScalarFunction( + new FunctionName(name), argTypes, + expression.getDataType().toCatalogDataType(), hasVarArguments, + "", TFunctionBinaryType.BUILTIN, true, true, nullableMode); + + FunctionCallExpr functionCallExpr; + // create catalog FunctionCallExpr without analyze again + functionCallExpr = new FunctionCallExpr(catalogFunction, new FunctionParams(false, arguments)); + functionCallExpr.setNullableFromNereids(expression.nullable()); + return functionCallExpr; + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java index 75dde73e742cc55..004930193e9fd77 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java @@ -23,6 +23,9 @@ import org.apache.doris.catalog.Database; import org.apache.doris.catalog.Env; import org.apache.doris.catalog.FunctionSearchDesc; +import org.apache.doris.common.ErrorCode; +import org.apache.doris.common.ErrorReport; +import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; @@ -64,6 +67,10 @@ public DropFunctionCommand(SetType setType, boolean ifExists, FunctionName funct @Override public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { + // check operation privilege + if (!Env.getCurrentEnv().getAccessManager().checkGlobalPriv(ConnectContext.get(), PrivPredicate.ADMIN)) { + ErrorReport.reportAnalysisException(ErrorCode.ERR_SPECIFIC_ACCESS_DENIED_ERROR, "ADMIN"); + } argsDef.validate(); function = new FunctionSearchDesc(functionName, argsDef.getArgTypes(), argsDef.isVariadic()); if (SetType.GLOBAL.equals(setType)) { @@ -72,6 +79,7 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { String dbName = functionName.getDb(); if (dbName == null) { dbName = ctx.getDatabase(); + functionName.setDb(dbName); } Database db = Env.getCurrentInternalCatalog().getDbOrDdlException(dbName); db.dropFunction(function, ifExists); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java index b8f93c0e6ca2782..e3da48bfb7cbed0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java @@ -42,6 +42,10 @@ public Type[] getArgTypes() { return argTypes; } + public List getArgTypeDefs() { + return argTypeDefs; + } + public boolean isVariadic() { return isVariadic; } diff --git a/regression-test/suites/ddl_p0/test_alias_function.groovy b/regression-test/suites/ddl_p0/test_alias_function.groovy index fa2ced713f81464..7793de925531fb4 100644 --- a/regression-test/suites/ddl_p0/test_alias_function.groovy +++ b/regression-test/suites/ddl_p0/test_alias_function.groovy @@ -18,10 +18,10 @@ suite("test_alias_function") { sql """DROP FUNCTION IF EXISTS mesh_udf_test1(INT,INT)""" - sql """CREATE ALIAS FUNCTION IF NOT EXISTS mesh_udf_test1(INT,INT) WITH PARAMETER(n,d) AS ROUND(1+floor(n/d));""" + sql """CREATE ALIAS FUNCTION mesh_udf_test1(INT,INT) WITH PARAMETER(n,d) AS ROUND(1+floor(n/d));""" qt_sql1 """select mesh_udf_test1(1,2);""" sql """DROP FUNCTION IF EXISTS mesh_udf_test2(INT,INT)""" - sql """CREATE ALIAS FUNCTION IF NOT EXISTS mesh_udf_test2(INT,INT) WITH PARAMETER(n,d) AS add(1,floor(divide(n,d)))""" + sql """CREATE ALIAS FUNCTION mesh_udf_test2(INT,INT) WITH PARAMETER(n,d) AS add(1,floor(divide(n,d)))""" qt_sql1 """select mesh_udf_test2(1,2);""" } diff --git a/regression-test/suites/query_p0/sql_functions/test_alias_function.groovy b/regression-test/suites/query_p0/sql_functions/test_alias_function.groovy index 095ec89e220f1b6..8b281d6faa05216 100644 --- a/regression-test/suites/query_p0/sql_functions/test_alias_function.groovy +++ b/regression-test/suites/query_p0/sql_functions/test_alias_function.groovy @@ -16,6 +16,12 @@ // under the License. suite('test_alias_function', "arrow_flight_sql") { + sql ''' + DROP FUNCTION IF EXISTS f1() + ''' + sql ''' + DROP FUNCTION IF EXISTS f2() + ''' sql ''' CREATE ALIAS FUNCTION IF NOT EXISTS f1(DATETIMEV2(3), INT) with PARAMETER (datetime1, int1) as date_trunc(days_sub(datetime1, int1), 'day')'''