Skip to content

Commit

Permalink
[8.x] ESQL: optimise aggregations filtered by false/null into evals (e…
Browse files Browse the repository at this point in the history
…lastic#115858) (elastic#116713)

* ESQL: optimise aggregations filtered by false/null into evals (elastic#115858)

This adds a new optimiser rule that extracts aggregate functions filtered by `FALSE` or `NULL` into evals. The value produced by the eval is `0L` for `COUNT()` and `COUNT_DISTINCT()`, and `NULL` otherwise.

Example:
```
... | STATS x = someAgg(y) WHERE FALSE {BY z} | ...
=>
... | STATS x = someAgg(y) {BY z} | EVAL x = NULL | KEEP x{, z} | ...
```

Related: elastic#114352.

* swap out list's getFirst/Last
  • Loading branch information
bpintea authored Nov 13, 2024
1 parent 08f8312 commit fa541d2
Show file tree
Hide file tree
Showing 7 changed files with 505 additions and 7 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/115858.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 115858
summary: "ESQL: optimise aggregations filtered by false/null into evals"
area: ES|QL
type: enhancement
issues: []
110 changes: 110 additions & 0 deletions x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
Original file line number Diff line number Diff line change
Expand Up @@ -2382,6 +2382,116 @@ max:integer |max_a:integer|min:integer | min_a:integer
74999 |null |25324 | null
;

statsWithAllFiltersFalse
required_capability: per_agg_filtering
from employees
| stats max = max(height.float) where false,
min = min(height.float) where to_string(null) == "abc",
count = count(height.float) where false,
count_distinct = count_distinct(salary) where to_string(null) == "def"
;

max:double |min:double |count:long |count_distinct:long
null |null |0 |0
;

statsWithExpressionsAllFiltersFalse
required_capability: per_agg_filtering
from employees
| stats max = max(height.float + 1) where null,
count = count(height.float) + 2 where false,
mix = min(height.float + 1) + count_distinct(emp_no) + 2 where length(null) == 3
;

max:double |count:long |mix:double
null |2 |null
;

statsWithFalseFilterAndGroup
required_capability: per_agg_filtering
from employees
| stats max = max(height.float + 1) where null,
count = count(height.float) + 2 where false
by job_positions
| sort job_positions
| limit 4
;

max:double |count:long |job_positions:keyword
null |2 |Accountant
null |2 |Architect
null |2 |Business Analyst
null |2 |Data Scientist
;

statsWithFalseFiltersAndGroups
required_capability: per_agg_filtering
from employees
| eval my_length = length(concat(first_name, null))
| stats count_distinct = count_distinct(height.float + 1) where null,
count = count(height.float) + 2 where false,
values = values(first_name) where my_length > 3
by job_positions, is_rehired
| sort job_positions, is_rehired
| limit 10
;

count_distinct:long |count:long |values:keyword |job_positions:keyword |is_rehired:boolean
0 |2 |null |Accountant |false
0 |2 |null |Accountant |true
0 |2 |null |Accountant |null
0 |2 |null |Architect |false
0 |2 |null |Architect |true
0 |2 |null |Architect |null
0 |2 |null |Business Analyst |false
0 |2 |null |Business Analyst |true
0 |2 |null |Business Analyst |null
0 |2 |null |Data Scientist |false
;

statsWithMixedFiltersAndGroup
required_capability: per_agg_filtering
from employees
| eval my_length = length(concat(first_name, null))
| stats count = count(my_length) where false,
values = mv_slice(mv_sort(values(first_name)), 0, 1)
by job_positions
| sort job_positions
| limit 4
;

count:long |values:keyword |job_positions:keyword
0 |[Arumugam, Bojan] |Accountant
0 |[Alejandro, Charlene] |Architect
0 |[Basil, Breannda] |Business Analyst
0 |[Berni, Breannda] |Data Scientist
;

prunedStatsFollowedByStats
required_capability: per_agg_filtering
from employees
| eval my_length = length(concat(first_name, null))
| stats count = count(my_length) where false,
values = mv_slice(values(first_name), 0, 1) where my_length > 0
| stats count_distinct = count_distinct(count)
;

count_distinct:long
1
;

statsWithFalseFiltersFromRow
required_capability: per_agg_filtering
row x = null, a = 1, b = [2,3,4]
| stats c=max(a) where x
by b
;

c:integer |b:integer
null |2
null |3
null |4
;

statsWithBasicExpressionFiltered
required_capability: per_agg_filtering
from employees
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
package org.elasticsearch.xpack.esql.optimizer;

import org.elasticsearch.xpack.esql.optimizer.rules.logical.PropagateEmptyRelation;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceStatsFilteredAggWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.InferIsNotNull;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.InferNonNullAggConstraint;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.LocalPropagateEmptyRelation;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceMissingFieldWithNull;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceTopNWithLimitAndSort;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.rule.ParameterizedRuleExecutor;
import org.elasticsearch.xpack.esql.rule.Rule;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -50,20 +52,31 @@ protected List<Batch<LogicalPlan>> batches() {
rules.add(local);
// TODO: if the local rules haven't touched the tree, the rest of the rules can be skipped
rules.addAll(asList(operators(), cleanup()));
replaceRules(rules);
return rules;
return replaceRules(rules);
}

@SuppressWarnings("unchecked")
private List<Batch<LogicalPlan>> replaceRules(List<Batch<LogicalPlan>> listOfRules) {
for (Batch<LogicalPlan> batch : listOfRules) {
List<Batch<LogicalPlan>> newBatches = new ArrayList<>(listOfRules.size());
for (var batch : listOfRules) {
var rules = batch.rules();
for (int i = 0; i < rules.length; i++) {
if (rules[i] instanceof PropagateEmptyRelation) {
rules[i] = new LocalPropagateEmptyRelation();
List<Rule<?, LogicalPlan>> newRules = new ArrayList<>(rules.length);
boolean updated = false;
for (var r : rules) {
if (r instanceof PropagateEmptyRelation) {
newRules.add(new LocalPropagateEmptyRelation());
updated = true;
} else if (r instanceof ReplaceStatsFilteredAggWithEval) {
// skip it: once a fragment contains an Agg, this can no longer be pruned, which the rule can do
updated = true;
} else {
newRules.add(r);
}
}
batch = updated ? batch.with(newRules.toArray(Rule[]::new)) : batch;
newBatches.add(batch);
}
return listOfRules;
return newBatches;
}

public LogicalPlan localOptimize(LogicalPlan plan) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceLimitAndSortAsTopN;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceOrderByExpressionWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceRegexMatch;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceStatsFilteredAggWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceTrivialTypeConversions;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.SetAsOptimized;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.SimplifyComparisonsArithmetics;
Expand Down Expand Up @@ -170,6 +171,7 @@ protected static Batch<LogicalPlan> operators() {
new CombineBinaryComparisons(),
new CombineDisjunctions(),
new SimplifyComparisonsArithmetics(DataType::areCompatible),
new ReplaceStatsFilteredAggWithEval(),
// prune/elimination
new PruneFilters(),
new PruneColumns(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;

import java.util.ArrayList;
import java.util.List;

/**
 * Moves aggregate functions whose filter can never match out of the {@code STATS} and into an {@code EVAL}.
 * <p>
 * An aggregate is picked up when it is an aliased {@link AggregateFunction} whose filter is a {@link Literal} folding to
 * {@code Boolean.FALSE}. It is replaced by a constant: {@code 0L} for {@link Count} and {@link CountDistinct},
 * {@code null} for every other function.
 * NOTE(review): the check only matches a literal folding to FALSE; a NULL filter is presumably turned into FALSE by an
 * earlier rule - confirm.
 * <pre>
 *     ... | STATS x = someAgg(y) WHERE FALSE, other = otherAgg(y) {BY z} | ...
 *     =>
 *     ... | STATS other = otherAgg(y) {BY z} | EVAL x = NULL | KEEP x, other{, z} | ...
 * </pre>
 * When nothing is left in the {@code STATS} (no other aggregate, no grouping key), the whole node is replaced by a
 * single-row local relation holding the constants.
 */
public class ReplaceStatsFilteredAggWithEval extends OptimizerRules.OptimizerRule<Aggregate> {
    @Override
    protected LogicalPlan rule(Aggregate aggregate) {
        var aggregates = aggregate.aggregates();
        int capacity = aggregates.size();
        List<NamedExpression> keptAggs = new ArrayList<>(capacity);
        List<Alias> constantEvals = new ArrayList<>(capacity);
        // output columns, in the same order as the original aggregates
        List<NamedExpression> projections = new ArrayList<>(capacity);

        for (var expression : aggregates) {
            Alias replacement = constantReplacement(expression);
            if (replacement == null) {
                // aggregate function that stays as it is, or a grouping key
                keptAggs.add(expression);
                projections.add(expression.toAttribute());
            } else {
                constantEvals.add(replacement);
                projections.add(replacement.toAttribute());
            }
        }

        if (constantEvals.isEmpty()) {
            // nothing to extract: leave the node untouched
            return aggregate;
        }
        if (keptAggs.isEmpty()) {
            // every output was turned into a constant: the Aggregate node is pruned altogether
            return localRelation(aggregate.source(), constantEvals);
        }
        LogicalPlan trimmedAggregate = aggregate.with(aggregate.child(), aggregate.groupings(), keptAggs);
        LogicalPlan withConstants = new Eval(aggregate.source(), trimmedAggregate, constantEvals);
        // the projection restores the original column order
        return new Project(aggregate.source(), withConstants, projections);
    }

    /**
     * Builds the constant alias standing in for an aggregate whose filter is a literal folding to {@code FALSE}.
     *
     * @param expression one entry of the aggregates list
     * @return the alias re-pointed to a constant literal, or {@code null} if the expression does not qualify
     */
    private static Alias constantReplacement(NamedExpression expression) {
        if (expression instanceof Alias alias
            && alias.child() instanceof AggregateFunction function
            && function.hasFilter()
            && function.filter() instanceof Literal filter
            && Boolean.FALSE.equals(filter.fold())) {

            boolean counting = function instanceof Count || function instanceof CountDistinct;
            Object constant = counting ? 0L : null;
            return alias.replaceChild(Literal.of(function, constant));
        }
        return null;
    }

    /**
     * Creates a local relation with one constant block of position count 1 per given alias.
     *
     * @param source   the source to attach to the relation
     * @param newEvals aliases whose children are {@link Literal}s
     * @return the relation exposing the aliases' attributes
     */
    private static LocalRelation localRelation(Source source, List<Alias> newEvals) {
        int count = newEvals.size();
        List<Attribute> output = new ArrayList<>(count);
        Block[] columns = new Block[count];
        int position = 0;
        for (Alias eval : newEvals) {
            output.add(eval.toAttribute());
            Object constant = ((Literal) eval.child()).value();
            columns[position++] = BlockUtils.constantBlock(PlannerUtils.NON_BREAKING_BLOCK_FACTORY, constant, 1);
        }
        return new LocalRelation(source, output, LocalSupplier.of(columns));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ public String name() {
return name;
}

/**
 * Creates a batch that reuses this batch's {@code name} and {@code limit} but holds the given rules.
 *
 * @param rules the rules of the returned batch
 * @return a new {@code Batch} built from {@code name}, {@code limit} and {@code rules}
 */
public Batch<TreeType> with(Rule<?, TreeType>[] rules) {
    Batch<TreeType> replacement = new Batch<>(name, limit, rules);
    return replacement;
}

/**
 * Returns the rules of this batch.
 * The {@code rules} reference is handed out as is, without copying.
 *
 * @return the rules array
 */
public Rule<?, TreeType>[] rules() {
return rules;
}
Expand Down
Loading

0 comments on commit fa541d2

Please sign in to comment.