Skip to content

Commit

Permalink
fix case
Browse files Browse the repository at this point in the history
  • Loading branch information
starocean999 committed Dec 27, 2024
1 parent 28bfed8 commit 02ccc8e
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4429,7 +4429,11 @@ public Command visitCreateUserDefineFunction(CreateUserDefineFunctionContext ctx
functionArgsDefInfo = new FunctionArgsDefInfo(new ArrayList<>(), false);
}
DataType returnType = typedVisit(ctx.returnType);
returnType = returnType.conversion();
DataType intermediateType = ctx.intermediateType != null ? typedVisit(ctx.intermediateType) : null;
if (intermediateType != null) {
intermediateType = intermediateType.conversion();
}
Map<String, String> properties = ctx.propertyClause() != null
? Maps.newHashMap(visitPropertyClause(ctx.propertyClause()))
: Maps.newHashMap();
Expand Down Expand Up @@ -4482,9 +4486,11 @@ public Command visitDropFunction(DropFunctionContext ctx) {
String functionName = ctx.functionIdentifier().functionNameIdentifier().getText();
String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null;
FunctionName function = new FunctionName(dbName, functionName);
FunctionArgsDefInfo functionArgsDefInfo = null;
FunctionArgsDefInfo functionArgsDefInfo;
if (ctx.functionArguments() != null) {
functionArgsDefInfo = visitFunctionArguments(ctx.functionArguments());
} else {
functionArgsDefInfo = new FunctionArgsDefInfo(new ArrayList<>(), false);
}
return new DropFunctionCommand(setType, ifExists, function, functionArgsDefInfo);
}
Expand All @@ -4497,7 +4503,7 @@ public FunctionArgsDefInfo visitFunctionArguments(FunctionArgumentsContext ctx)
if (child instanceof FunctionArgumentContext) {
DataType dataType = visitFunctionArgument((FunctionArgumentContext) child);
if (dataType != null) {
argTypeDefs.add(dataType);
argTypeDefs.add(dataType.conversion());
} else {
isVariadic = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
package org.apache.doris.nereids.trees.plans.commands;

import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.FunctionName;
import org.apache.doris.analysis.FunctionParams;
import org.apache.doris.analysis.SetType;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.StmtType;
import org.apache.doris.catalog.AggregateFunction;
import org.apache.doris.catalog.AliasFunction;
Expand All @@ -43,10 +46,26 @@
import org.apache.doris.common.util.Util;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.glue.translator.ExpressionTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.BitAnd;
import org.apache.doris.nereids.trees.expressions.BitNot;
import org.apache.doris.nereids.trees.expressions.BitOr;
import org.apache.doris.nereids.trees.expressions.BitXor;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.Mod;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
Expand All @@ -67,6 +86,7 @@
import io.grpc.ManagedChannel;
import io.grpc.netty.NettyChannelBuilder;
import org.apache.commons.codec.binary.Hex;
import org.apache.commons.collections.map.CaseInsensitiveMap;
import org.apache.commons.lang3.StringUtils;

import java.io.IOException;
Expand Down Expand Up @@ -211,6 +231,7 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception {
String dbName = functionName.getDb();
if (dbName == null) {
dbName = ctx.getDatabase();
functionName.setDb(dbName);
}
Database db = Env.getCurrentInternalCatalog().getDbOrDdlException(dbName);
db.addFunction(function, ifNotExists);
Expand Down Expand Up @@ -891,18 +912,144 @@ private TFunctionBinaryType getFunctionBinaryType(String type) {
private void analyzeAliasFunction(ConnectContext ctx) throws AnalysisException {
function = AliasFunction.createFunction(functionName, argsDef.getArgTypes(),
Type.VARCHAR, argsDef.isVariadic(), parameters, translateToLegacyExpr(originFunction, ctx));
((AliasFunction) function).analyze();
}

/**
* translate to legacy expr, which do not need complex expression and table columns
*/
private Expr translateToLegacyExpr(Expression expression, ConnectContext ctx) {
private Expr translateToLegacyExpr(Expression expression, ConnectContext ctx) throws AnalysisException {
LogicalEmptyRelation plan = new LogicalEmptyRelation(
ConnectContext.get().getStatementContext().getNextRelationId(), new ArrayList<>());
CascadesContext cascadesContext = CascadesContext.initContext(ctx.getStatementContext(), plan,
PhysicalProperties.ANY);
Map<String, DataType> argTypeMap = new CaseInsensitiveMap();
List<DataType> argTypes = argsDef.getArgTypeDefs();
if (!parameters.isEmpty()) {
if (parameters.size() != argTypes.size()) {
throw new AnalysisException(String.format("arguments' size must be same as parameters' size,"
+ "arguments : %s, parameters : %s", argTypes.size(), parameters.size()));
}
for (int i = 0; i < parameters.size(); ++i) {
argTypeMap.put(parameters.get(i), argTypes.get(i));
}
}
ExpressionAnalyzer analyzer = new CustomExpressionAnalyzer(cascadesContext, argTypeMap);
expression = analyzer.analyze(expression);

PlanTranslatorContext translatorContext = new PlanTranslatorContext(cascadesContext);
return ExpressionTranslator.translate(expression, translatorContext);
ExpressionToExpr translator = new ExpressionToExpr();
return expression.accept(translator, translatorContext);
}

private static class CustomExpressionAnalyzer extends ExpressionAnalyzer {
private Map<String, DataType> argTypeMap;

public CustomExpressionAnalyzer(CascadesContext cascadesContext, Map<String, DataType> argTypeMap) {
super(null, new Scope(ImmutableList.of()), cascadesContext, false, false);
this.argTypeMap = argTypeMap;
}

@Override
public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteContext context) {
DataType dataType = argTypeMap.get(unboundSlot.getName());
if (dataType == null) {
throw new org.apache.doris.nereids.exceptions.AnalysisException(
String.format("param %s's datatype is missed", unboundSlot.getName()));
}
return new SlotReference(unboundSlot.getName(), dataType);
}
}

private static class ExpressionToExpr extends ExpressionTranslator {
@Override
public Expr visitSlotReference(SlotReference slotReference, PlanTranslatorContext context) {
SlotRef slotRef = new SlotRef(slotReference.getDataType().toCatalogDataType(), slotReference.nullable());
slotRef.setLabel(slotReference.getName());
slotRef.setCol(slotReference.getName());
slotRef.setDisableTableName(true);
return slotRef;
}

@Override
public Expr visitBoundFunction(BoundFunction function, PlanTranslatorContext context) {
return makeFunctionCallExpr(function, function.getName(), function.hasVarArguments(), context);
}

@Override
public Expr visitAdd(Add add, PlanTranslatorContext context) {
return makeFunctionCallExpr(add, "add", false, context);
}

@Override
public Expr visitSubtract(Subtract subtract, PlanTranslatorContext context) {
return makeFunctionCallExpr(subtract, "subtract", false, context);
}

@Override
public Expr visitMultiply(Multiply multiply, PlanTranslatorContext context) {
return makeFunctionCallExpr(multiply, "multiply", false, context);
}

@Override
public Expr visitDivide(Divide divide, PlanTranslatorContext context) {
return makeFunctionCallExpr(divide, "divide", false, context);
}

@Override
public Expr visitIntegralDivide(IntegralDivide integralDivide, PlanTranslatorContext context) {
return makeFunctionCallExpr(integralDivide, "integralDivide", false, context);
}

@Override
public Expr visitMod(Mod mod, PlanTranslatorContext context) {
return makeFunctionCallExpr(mod, "mod", false, context);
}

@Override
public Expr visitBitAnd(BitAnd bitAnd, PlanTranslatorContext context) {
return makeFunctionCallExpr(bitAnd, "bitAnd", false, context);
}

@Override
public Expr visitBitOr(BitOr bitOr, PlanTranslatorContext context) {
return makeFunctionCallExpr(bitOr, "bitOr", false, context);
}

@Override
public Expr visitBitXor(BitXor bitXor, PlanTranslatorContext context) {
return makeFunctionCallExpr(bitXor, "bitXor", false, context);
}

@Override
public Expr visitBitNot(BitNot bitNot, PlanTranslatorContext context) {
return makeFunctionCallExpr(bitNot, "bitNot", false, context);
}

private Expr makeFunctionCallExpr(Expression expression, String name, boolean hasVarArguments,
PlanTranslatorContext context) {
List<Expr> arguments = expression.getArguments().stream()
.map(arg -> arg.accept(this, context))
.collect(Collectors.toList());

List<Type> argTypes = expression.getArguments().stream()
.map(Expression::getDataType)
.map(DataType::toCatalogDataType)
.collect(Collectors.toList());

NullableMode nullableMode = expression.nullable()
? NullableMode.ALWAYS_NULLABLE
: NullableMode.ALWAYS_NOT_NULLABLE;

org.apache.doris.catalog.ScalarFunction catalogFunction = new org.apache.doris.catalog.ScalarFunction(
new FunctionName(name), argTypes,
expression.getDataType().toCatalogDataType(), hasVarArguments,
"", TFunctionBinaryType.BUILTIN, true, true, nullableMode);

FunctionCallExpr functionCallExpr;
// create catalog FunctionCallExpr without analyze again
functionCallExpr = new FunctionCallExpr(catalogFunction, new FunctionParams(false, arguments));
functionCallExpr.setNullableFromNereids(expression.nullable());
return functionCallExpr;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import org.apache.doris.catalog.Database;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionSearchDesc;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.ErrorReport;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
Expand Down Expand Up @@ -64,6 +67,10 @@ public DropFunctionCommand(SetType setType, boolean ifExists, FunctionName funct

@Override
public void run(ConnectContext ctx, StmtExecutor executor) throws Exception {
// check operation privilege
if (!Env.getCurrentEnv().getAccessManager().checkGlobalPriv(ConnectContext.get(), PrivPredicate.ADMIN)) {
ErrorReport.reportAnalysisException(ErrorCode.ERR_SPECIFIC_ACCESS_DENIED_ERROR, "ADMIN");
}
argsDef.validate();
function = new FunctionSearchDesc(functionName, argsDef.getArgTypes(), argsDef.isVariadic());
if (SetType.GLOBAL.equals(setType)) {
Expand All @@ -72,6 +79,7 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception {
String dbName = functionName.getDb();
if (dbName == null) {
dbName = ctx.getDatabase();
functionName.setDb(dbName);
}
Database db = Env.getCurrentInternalCatalog().getDbOrDdlException(dbName);
db.dropFunction(function, ifExists);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ public Type[] getArgTypes() {
return argTypes;
}

public List<DataType> getArgTypeDefs() {
return argTypeDefs;
}

public boolean isVariadic() {
return isVariadic;
}
Expand Down
4 changes: 2 additions & 2 deletions regression-test/suites/ddl_p0/test_alias_function.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
suite("test_alias_function") {

sql """DROP FUNCTION IF EXISTS mesh_udf_test1(INT,INT)"""
sql """CREATE ALIAS FUNCTION IF NOT EXISTS mesh_udf_test1(INT,INT) WITH PARAMETER(n,d) AS ROUND(1+floor(n/d));"""
sql """CREATE ALIAS FUNCTION mesh_udf_test1(INT,INT) WITH PARAMETER(n,d) AS ROUND(1+floor(n/d));"""
qt_sql1 """select mesh_udf_test1(1,2);"""

sql """DROP FUNCTION IF EXISTS mesh_udf_test2(INT,INT)"""
sql """CREATE ALIAS FUNCTION IF NOT EXISTS mesh_udf_test2(INT,INT) WITH PARAMETER(n,d) AS add(1,floor(divide(n,d)))"""
sql """CREATE ALIAS FUNCTION mesh_udf_test2(INT,INT) WITH PARAMETER(n,d) AS add(1,floor(divide(n,d)))"""
qt_sql1 """select mesh_udf_test2(1,2);"""
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
// under the License.

suite('test_alias_function', "arrow_flight_sql") {
sql '''
DROP FUNCTION IF EXISTS f1()
'''
sql '''
DROP FUNCTION IF EXISTS f2()
'''
sql '''
CREATE ALIAS FUNCTION IF NOT EXISTS f1(DATETIMEV2(3), INT)
with PARAMETER (datetime1, int1) as date_trunc(days_sub(datetime1, int1), 'day')'''
Expand Down

0 comments on commit 02ccc8e

Please sign in to comment.