Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix](nereids) fix UnknownValue's reference in simplify range rule #44637

Merged
merged 11 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
/** rewrite */
public static Expression rewrite(CompoundPredicate expr, ExpressionRewriteContext context) {
ValueDesc valueDesc = expr.accept(new RangeInference(), context);
Expression exprForNonNull = valueDesc.toExpressionForNonNull();
if (exprForNonNull == null) {
Expression toExpr = valueDesc.toExpression();
if (toExpr == null) {
// this mean cannot simplify
return valueDesc.exprForNonNull;
return valueDesc.toExpr;
}
return exprForNonNull;
return toExpr;
}

private static class RangeInference extends ExpressionVisitor<ValueDesc, ExpressionRewriteContext> {
Expand Down Expand Up @@ -197,18 +197,18 @@ private ValueDesc simplify(ExpressionRewriteContext context,
}

// use UnknownValue to wrap different references
return new UnknownValue(context, valuePerRefs, originExpr, exprOp);
return new UnknownValue(context, originExpr, valuePerRefs, exprOp);
}
}

private abstract static class ValueDesc {
ExpressionRewriteContext context;
Expression exprForNonNull;
Expression toExpr;
Expression reference;

public ValueDesc(ExpressionRewriteContext context, Expression reference, Expression exprForNonNull) {
public ValueDesc(ExpressionRewriteContext context, Expression reference, Expression toExpr) {
this.context = context;
this.exprForNonNull = exprForNonNull;
this.toExpr = toExpr;
this.reference = reference;
}

Expand All @@ -220,28 +220,28 @@ public static ValueDesc union(ExpressionRewriteContext context,
if (count == discrete.values.size()) {
return range;
}
Expression exprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(range.exprForNonNull, discrete.exprForNonNull), context);
Expression toExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(range.toExpr, discrete.toExpr), context);
List<ValueDesc> sourceValues = reverseOrder
? ImmutableList.of(discrete, range)
: ImmutableList.of(range, discrete);
return new UnknownValue(context, sourceValues, exprForNonNull, ExpressionUtils::or);
return new UnknownValue(context, toExpr, sourceValues, ExpressionUtils::or);
}

public abstract ValueDesc intersect(ValueDesc other);

public static ValueDesc intersect(ExpressionRewriteContext context, RangeValue range, DiscreteValue discrete) {
DiscreteValue result = new DiscreteValue(context, discrete.reference, discrete.exprForNonNull);
DiscreteValue result = new DiscreteValue(context, discrete.reference, discrete.toExpr);
discrete.values.stream().filter(x -> range.range.contains(x)).forEach(result.values::add);
if (!result.values.isEmpty()) {
return result;
}
Expression originExprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(range.exprForNonNull, discrete.exprForNonNull), context);
return new EmptyValue(context, range.reference, originExprForNonNull);
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(range.toExpr, discrete.toExpr), context);
return new EmptyValue(context, range.reference, originExpr);
}

public abstract Expression toExpressionForNonNull();
public abstract Expression toExpression();

public static ValueDesc range(ExpressionRewriteContext context, ComparisonPredicate predicate) {
Literal value = (Literal) predicate.right();
Expand Down Expand Up @@ -271,8 +271,8 @@ public static ValueDesc discrete(ExpressionRewriteContext context, InPredicate i

private static class EmptyValue extends ValueDesc {

public EmptyValue(ExpressionRewriteContext context, Expression reference, Expression exprForNonNull) {
super(context, reference, exprForNonNull);
public EmptyValue(ExpressionRewriteContext context, Expression reference, Expression toExpr) {
super(context, reference, toExpr);
}

@Override
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not need to change union and intersect of EmptyValue?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not need to change union and intersect of EmptyValue?

if A op B's references are different, then generate a UnknownValue, this UnknownValue would keep both A and B, no delete any of them.

so if A op B got result is A, or result is B, then we known that A and B's referece must be the same.

so for A op B, if A is EmptyValue, then

  1. if B's reference different with A's reference, then generate a UnknownValue, UnknownValue's toExpression will keep both A and B
  2. if B's reference equals to A's reference, then EmptyValue union B = B is ok, EmptyValue intersect B = EmptyValue is ok too.

Expand All @@ -286,7 +286,7 @@ public ValueDesc intersect(ValueDesc other) {
}

@Override
public Expression toExpressionForNonNull() {
public Expression toExpression() {
if (reference.nullable()) {
return new And(new IsNull(reference), new NullLiteral(BooleanType.INSTANCE));
} else {
Expand All @@ -303,65 +303,63 @@ public Expression toExpressionForNonNull() {
private static class RangeValue extends ValueDesc {
Range<Literal> range;

public RangeValue(ExpressionRewriteContext context, Expression reference, Expression exprForNonNull) {
super(context, reference, exprForNonNull);
public RangeValue(ExpressionRewriteContext context, Expression reference, Expression toExpr) {
super(context, reference, toExpr);
}

@Override
public ValueDesc union(ValueDesc other) {
if (other instanceof EmptyValue) {
return other.union(this);
}
try {
if (other instanceof RangeValue) {
Expression originExprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(exprForNonNull, other.exprForNonNull), context);
RangeValue o = (RangeValue) other;
if (range.isConnected(o.range)) {
RangeValue rangeValue = new RangeValue(context, reference, originExprForNonNull);
rangeValue.range = range.span(o.range);
return rangeValue;
}
return new UnknownValue(context, ImmutableList.of(this, other),
originExprForNonNull, ExpressionUtils::or);
if (other instanceof RangeValue) {
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(toExpr, other.toExpr), context);
RangeValue o = (RangeValue) other;
if (range.isConnected(o.range)) {
RangeValue rangeValue = new RangeValue(context, reference, originExpr);
rangeValue.range = range.span(o.range);
return rangeValue;
}
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::or);
}
if (other instanceof DiscreteValue) {
return union(context, this, (DiscreteValue) other, false);
} catch (Exception e) {
Expression originExprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(exprForNonNull, other.exprForNonNull), context);
return new UnknownValue(context, ImmutableList.of(this, other),
originExprForNonNull, ExpressionUtils::or);
}
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::or);
}

@Override
public ValueDesc intersect(ValueDesc other) {
if (other instanceof EmptyValue) {
return other.intersect(this);
}
try {
if (other instanceof RangeValue) {
Expression originExprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(exprForNonNull, other.exprForNonNull), context);
RangeValue o = (RangeValue) other;
if (range.isConnected(o.range)) {
RangeValue rangeValue = new RangeValue(context, reference, originExprForNonNull);
rangeValue.range = range.intersection(o.range);
return rangeValue;
}
return new EmptyValue(context, reference, originExprForNonNull);
if (other instanceof RangeValue) {
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(toExpr, other.toExpr), context);
RangeValue o = (RangeValue) other;
if (range.isConnected(o.range)) {
RangeValue rangeValue = new RangeValue(context, reference, originExpr);
rangeValue.range = range.intersection(o.range);
return rangeValue;
}
return new EmptyValue(context, reference, originExpr);
}
if (other instanceof DiscreteValue) {
return intersect(context, this, (DiscreteValue) other);
} catch (Exception e) {
Expression originExprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(exprForNonNull, other.exprForNonNull), context);
return new UnknownValue(context, ImmutableList.of(this, other),
originExprForNonNull, ExpressionUtils::and);
}
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::and);
}

@Override
public Expression toExpressionForNonNull() {
public Expression toExpression() {
List<Expression> result = Lists.newArrayList();
if (range.hasLowerBound()) {
if (range.lowerBoundType() == BoundType.CLOSED) {
Expand Down Expand Up @@ -403,13 +401,13 @@ private static class DiscreteValue extends ValueDesc {
Set<Literal> values;

public DiscreteValue(ExpressionRewriteContext context,
Expression reference, Expression exprForNonNull, Literal... values) {
this(context, reference, exprForNonNull, Arrays.asList(values));
Expression reference, Expression toExpr, Literal... values) {
this(context, reference, toExpr, Arrays.asList(values));
}

public DiscreteValue(ExpressionRewriteContext context,
Expression reference, Expression exprForNonNull, Collection<Literal> values) {
super(context, reference, exprForNonNull);
Expression reference, Expression toExpr, Collection<Literal> values) {
super(context, reference, toExpr);
this.values = Sets.newTreeSet(values);
}

Expand All @@ -418,53 +416,51 @@ public ValueDesc union(ValueDesc other) {
if (other instanceof EmptyValue) {
return other.union(this);
}
try {
if (other instanceof DiscreteValue) {
Expression originExprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(exprForNonNull, other.exprForNonNull), context);
DiscreteValue discreteValue = new DiscreteValue(context, reference, originExprForNonNull);
discreteValue.values.addAll(((DiscreteValue) other).values);
discreteValue.values.addAll(this.values);
return discreteValue;
}
if (other instanceof DiscreteValue) {
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(toExpr, other.toExpr), context);
DiscreteValue discreteValue = new DiscreteValue(context, reference, originExpr);
discreteValue.values.addAll(((DiscreteValue) other).values);
discreteValue.values.addAll(this.values);
return discreteValue;
}
if (other instanceof RangeValue) {
return union(context, (RangeValue) other, this, true);
} catch (Exception e) {
Expression originExprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(exprForNonNull, other.exprForNonNull), context);
return new UnknownValue(context, ImmutableList.of(this, other),
originExprForNonNull, ExpressionUtils::or);
}
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::or);
}

@Override
public ValueDesc intersect(ValueDesc other) {
if (other instanceof EmptyValue) {
return other.intersect(this);
}
try {
if (other instanceof DiscreteValue) {
Expression originExprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(exprForNonNull, other.exprForNonNull), context);
DiscreteValue discreteValue = new DiscreteValue(context, reference, originExprForNonNull);
discreteValue.values.addAll(((DiscreteValue) other).values);
discreteValue.values.retainAll(this.values);
if (discreteValue.values.isEmpty()) {
return new EmptyValue(context, reference, originExprForNonNull);
} else {
return discreteValue;
}
if (other instanceof DiscreteValue) {
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(toExpr, other.toExpr), context);
DiscreteValue discreteValue = new DiscreteValue(context, reference, originExpr);
discreteValue.values.addAll(((DiscreteValue) other).values);
discreteValue.values.retainAll(this.values);
if (discreteValue.values.isEmpty()) {
return new EmptyValue(context, reference, originExpr);
} else {
return discreteValue;
}
}
if (other instanceof RangeValue) {
return intersect(context, (RangeValue) other, this);
} catch (Exception e) {
Expression originExprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(exprForNonNull, other.exprForNonNull), context);
return new UnknownValue(context, ImmutableList.of(this, other),
originExprForNonNull, ExpressionUtils::and);
}
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::and);
}

@Override
public Expression toExpressionForNonNull() {
public Expression toExpression() {
// NOTICE: it's related with `InPredicateToEqualToRule`
// They are same processes, so must change synchronously.
if (values.size() == 1) {
Expand Down Expand Up @@ -498,41 +494,53 @@ private UnknownValue(ExpressionRewriteContext context, Expression expr) {
mergeExprOp = null;
}

public UnknownValue(ExpressionRewriteContext context,
List<ValueDesc> sourceValues, Expression exprForNonNull, BinaryOperator<Expression> mergeExprOp) {
super(context, sourceValues.get(0).reference, exprForNonNull);
public UnknownValue(ExpressionRewriteContext context, Expression toExpr,
List<ValueDesc> sourceValues, BinaryOperator<Expression> mergeExprOp) {
super(context, genReference(sourceValues, toExpr), toExpr);
this.sourceValues = ImmutableList.copyOf(sourceValues);
this.mergeExprOp = mergeExprOp;
}

private static Expression genReference(List<ValueDesc> sourceValues, Expression toExpr) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some comment to explain why need generate reference from toExpr

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some comment to explain why need generate reference from toExpr

had add comment

Expression reference = sourceValues.get(0).reference;
for (int i = 1; i < sourceValues.size(); i++) {
if (!reference.equals(sourceValues.get(i).reference)) {
return toExpr;
}
}
return reference;
}

@Override
public ValueDesc union(ValueDesc other) {
Expression originExprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(exprForNonNull, other.exprForNonNull), context);
return new UnknownValue(context, ImmutableList.of(this, other), originExprForNonNull, ExpressionUtils::or);
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::or);
}

@Override
public ValueDesc intersect(ValueDesc other) {
Expression originExprForNonNull = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(exprForNonNull, other.exprForNonNull), context);
return new UnknownValue(context, ImmutableList.of(this, other), originExprForNonNull, ExpressionUtils::and);
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::and);
}

@Override
public Expression toExpressionForNonNull() {
public Expression toExpression() {
if (sourceValues.isEmpty()) {
return exprForNonNull;
return toExpr;
}
Expression result = sourceValues.get(0).toExpressionForNonNull();
Expression result = sourceValues.get(0).toExpression();
for (int i = 1; i < sourceValues.size(); i++) {
result = mergeExprOp.apply(result, sourceValues.get(i).toExpressionForNonNull());
result = mergeExprOp.apply(result, sourceValues.get(i).toExpression());
}
result = FoldConstantRuleOnFE.evaluate(result, context);
// ATTN: we must return original expr, because OrToIn is implemented with MutableState,
// newExpr will lose these states leading to dead loop by OrToIn -> SimplifyRange -> FoldConstantByFE
if (result.equals(exprForNonNull)) {
return exprForNonNull;
if (result.equals(toExpr)) {
return toExpr;
}
return result;
}
Expand Down
Loading
Loading