Skip to content

Commit

Permalink
[BugFix] add an id argument to mark non-deterministic functions (back…
Browse files Browse the repository at this point in the history
…port #46592) (#46605)

Signed-off-by: packy92 <[email protected]>
Co-authored-by: packy92 <[email protected]>
Co-authored-by: packy92 <[email protected]>
  • Loading branch information
3 people authored Jun 5, 2024
1 parent d4b7d68 commit ffc356a
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

/**
* Scalar operator support function call
* Please be careful when adding new attributes. Rewriting expr operation exists everywhere in the optimizer.
* If you add new attributes, please make sure that the new attributes will not be erased by the rewriting operation.
*/
public class CallOperator extends ScalarOperator {
private String fnName;
Expand All @@ -38,9 +40,6 @@ public class CallOperator extends ScalarOperator {
// Ignore nulls.
private boolean ignoreNulls = false;

// for nonDeterministicFunctions, to reuse it in common exprs
private int id = 0;

public CallOperator(String fnName, Type returnType, List<ScalarOperator> arguments) {
this(fnName, returnType, arguments, null);
}
Expand Down Expand Up @@ -86,10 +85,6 @@ public boolean isAggregate() {
return fn != null && fn instanceof AggregateFunction;
}

public void setId(int id) {
this.id = id;
}

@Override
public String toString() {
return fnName + "(" + (isDistinct ? "distinct " : "") +
Expand Down Expand Up @@ -152,7 +147,7 @@ public ColumnRefSet getUsedColumns() {

@Override
public int hashCode() {
return Objects.hash(fnName, arguments, isDistinct, id);
return Objects.hash(fnName, arguments, isDistinct);
}

@Override
Expand All @@ -167,8 +162,7 @@ public boolean equals(Object obj) {
return isDistinct == other.isDistinct &&
Objects.equals(fnName, other.fnName) &&
Objects.equals(type, other.type) &&
Objects.equals(arguments, other.arguments) &&
id == other.id;
Objects.equals(arguments, other.arguments);
}

@Override
Expand All @@ -185,7 +179,6 @@ public ScalarOperator clone() {
operator.fnName = this.fnName;
operator.isDistinct = this.isDistinct;
operator.ignoreNulls = this.ignoreNulls;
operator.id = this.id;
return operator;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public class BaseScalarOperatorShuttle extends ScalarOperatorVisitor<ScalarOpera
.put(ArraySliceOperator.class, (op, childOps) -> new ArraySliceOperator(op.getType(), childOps))
.put(CallOperator.class, (op, childOps) -> {
CallOperator call = (CallOperator) op;
return new CallOperator(call.getFnName(), call.getType(), childOps, call.getFunction(), call.isDistinct()); })
return new CallOperator(call.getFnName(), call.getType(), childOps, call.getFunction(),
call.isDistinct()); })
.put(PredicateOperator.class, (op, childOps) -> op)
.put(BetweenPredicateOperator.class, (op, childOps) -> {
BetweenPredicateOperator between = (BetweenPredicateOperator) op;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ private ScalarOperator tryRewrite(ScalarOperator operator) {

@Override
public ScalarOperator visitCall(CallOperator call, Void context) {
ScalarOperator operator = new CallOperator(call.getFnName(),
CallOperator operator = new CallOperator(call.getFnName(),
call.getType(),
call.getChildren().stream().map(argument -> argument.accept(this, null)).collect(toImmutableList()),
call.getFunction(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,16 +603,19 @@ public ScalarOperator visitFunctionCall(FunctionCallExpr node, Context context)
.map(child -> visit(child, context.clone(node)))
.collect(Collectors.toList());

// for nonDeterministicFunctions, we need add an argument as its unique id to distinguish
// the reusing behavior in common exprs
if (FunctionSet.nonDeterministicFunctions.contains(node.getFnName().getFunction())) {
arguments.add(ConstantOperator.createInt(columnRefFactory.getNextUniqueId()));
}

CallOperator callOperator = new CallOperator(
node.getFnName().getFunction(),
node.getType(),
arguments,
node.getFn(),
node.getParams().isDistinct());
callOperator.setHints(node.getHints());
if (FunctionSet.nonDeterministicFunctions.contains(node.getFnName().getFunction())) {
callOperator.setId(columnRefFactory.getNextUniqueId());
}
return callOperator;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,19 @@ public Expr visitCall(CallOperator call, FormatterContext context) {
"",
((ConstantOperator) call.getChild(0)).getBigint());
break;
case "rand":
case "random":
case "uuid":
case "sleep":
List<Expr> arguments = Lists.newArrayList();
if (call.getChildren().size() == 2) {
arguments.add(buildExpr.build(call.getChild(0), context));
}
callExpr = new FunctionCallExpr(call.getFnName(), new FunctionParams(false, arguments));
Preconditions.checkNotNull(call.getFunction());
callExpr.setFn(call.getFunction());
callExpr.setIgnoreNulls(call.getIgnoreNulls());
break;
default:
List<Expr> arg = call.getChildren().stream()
.map(expr -> buildExpr.build(expr, context))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ public void testRandomReuse() throws Exception {
" | <slot 2> : random()\n" +
" | <slot 3> : random()");
}

{
String query = "select a, b, a + b from (select random() * 1000 a, random() * 1000 b from t0) t";
String plan = getFragmentPlan(query);
PlanTestBase.assertContains(plan, "1:Project\n" +
" | <slot 4> : 9: multiply\n" +
" | <slot 5> : 10: multiply\n" +
" | <slot 6> : 9: multiply + 10: multiply\n" +
" | common expressions:\n" +
" | <slot 7> : random()\n" +
" | <slot 8> : random()\n" +
" | <slot 9> : 7: random * 1000.0\n" +
" | <slot 10> : 8: random * 1000.0");
}
}

@Test
Expand Down

0 comments on commit ffc356a

Please sign in to comment.