Skip to content

Commit

Permalink
Stop retaining reference to intermediary aggregation results in Query…
Browse files Browse the repository at this point in the history
…PhaseResultConsumer (elastic#119984)

We retained a reference to the partial `MergeResult` until after the search response has been sent.
This can waste a lot of memory in some cases where partial merges don't do much to reduce memory consumption.
Lets `null` out all the fields that may retain heavyweight references on `reduce`.
Also, creating new lists saves churn and makes it easier to reason about things for the 2 mutable lists
this makes non-final and saves some copying.
  • Loading branch information
original-brownbear authored Jan 16, 2025
1 parent 6b1112d commit 4b59fa7
Showing 1 changed file with 55 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
Expand Down Expand Up @@ -67,8 +66,8 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
private final Consumer<Exception> onPartialMergeFailure;

private final int batchReduceSize;
private final List<QuerySearchResult> buffer = new ArrayList<>();
private final List<SearchShard> emptyResults = new ArrayList<>();
private List<QuerySearchResult> buffer = new ArrayList<>();
private List<SearchShard> emptyResults = new ArrayList<>();
// the memory that is accounted in the circuit breaker for this consumer
private volatile long circuitBreakerBytes;
// the memory that is currently used in the buffer
Expand Down Expand Up @@ -159,32 +158,40 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
if (f != null) {
throw f;
}

List<QuerySearchResult> buffer;
synchronized (this) {
// final reduce, we're done with the buffer so we just null it out and continue with a local variable to
// save field references. The synchronized block is never contended but needed to have a memory barrier and sync buffer's
// contents with all the previous writers to it
buffer = this.buffer;
buffer = buffer == null ? Collections.emptyList() : buffer;
this.buffer = null;
}
// ensure consistent ordering
sortBuffer();
buffer.sort(RESULT_COMPARATOR);
final TopDocsStats topDocsStats = this.topDocsStats;
var mergeResult = this.mergeResult;
this.mergeResult = null;
final int resultSize = buffer.size() + (mergeResult == null ? 0 : 1);
final List<TopDocs> topDocsList = hasTopDocs ? new ArrayList<>(resultSize) : null;
final List<DelayableWriteable<InternalAggregations>> aggsList = hasAggs ? new ArrayList<>(resultSize) : null;
synchronized (this) {
if (mergeResult != null) {
if (topDocsList != null) {
topDocsList.add(mergeResult.reducedTopDocs);
}
if (aggsList != null) {
aggsList.add(DelayableWriteable.referencing(mergeResult.reducedAggs));
}
if (mergeResult != null) {
if (topDocsList != null) {
topDocsList.add(mergeResult.reducedTopDocs);
}
for (QuerySearchResult result : buffer) {
topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
if (topDocsList != null) {
TopDocsAndMaxScore topDocs = result.consumeTopDocs();
setShardIndex(topDocs.topDocs, result.getShardIndex());
topDocsList.add(topDocs.topDocs);
}
if (aggsList != null) {
aggsList.add(result.getAggs());
}
if (aggsList != null) {
aggsList.add(DelayableWriteable.referencing(mergeResult.reducedAggs));
}
}
for (QuerySearchResult result : buffer) {
topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
if (topDocsList != null) {
TopDocsAndMaxScore topDocs = result.consumeTopDocs();
setShardIndex(topDocs.topDocs, result.getShardIndex());
topDocsList.add(topDocs.topDocs);
}
if (aggsList != null) {
aggsList.add(result.getAggs());
}
}
SearchPhaseController.ReducedQueryPhase reducePhase;
Expand All @@ -206,7 +213,7 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
performFinalReduce
);
} finally {
releaseAggs();
releaseAggs(buffer);
}
if (hasAggs
// reduced aggregations can be null if all shards failed
Expand All @@ -226,25 +233,25 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
);
}
return reducePhase;

}

private static final Comparator<QuerySearchResult> RESULT_COMPARATOR = Comparator.comparingInt(QuerySearchResult::getShardIndex);

private MergeResult partialReduce(
QuerySearchResult[] toConsume,
List<SearchShard> emptyResults,
List<QuerySearchResult> toConsume,
List<SearchShard> processedShards,
TopDocsStats topDocsStats,
MergeResult lastMerge,
int numReducePhases
) {
// ensure consistent ordering
Arrays.sort(toConsume, RESULT_COMPARATOR);
toConsume.sort(RESULT_COMPARATOR);

final List<SearchShard> processedShards = new ArrayList<>(emptyResults);
final TopDocs newTopDocs;
final InternalAggregations newAggs;
final List<DelayableWriteable<InternalAggregations>> aggsList;
final int resultSetSize = toConsume.length + (lastMerge != null ? 1 : 0);
final int resultSetSize = toConsume.size() + (lastMerge != null ? 1 : 0);
if (hasAggs) {
aggsList = new ArrayList<>(resultSetSize);
if (lastMerge != null) {
Expand Down Expand Up @@ -307,12 +314,6 @@ private boolean hasPendingMerges() {
return queue.isEmpty() == false || runningTask.get() != null;
}

void sortBuffer() {
if (buffer.size() > 0) {
buffer.sort(RESULT_COMPARATOR);
}
}

private synchronized void addWithoutBreaking(long size) {
circuitBreaker.addWithoutBreaking(size);
circuitBreakerBytes += size;
Expand Down Expand Up @@ -376,21 +377,21 @@ private void consume(QuerySearchResult result, Runnable next) {
}
}
if (hasFailure == false) {
var b = buffer;
aggsCurrentBufferSize += aggsSize;
// add one if a partial merge is pending
int size = buffer.size() + (hasPartialReduce ? 1 : 0);
int size = b.size() + (hasPartialReduce ? 1 : 0);
if (size >= batchReduceSize) {
hasPartialReduce = true;
executeNextImmediately = false;
QuerySearchResult[] clone = buffer.toArray(QuerySearchResult[]::new);
MergeTask task = new MergeTask(clone, aggsCurrentBufferSize, new ArrayList<>(emptyResults), next);
MergeTask task = new MergeTask(b, aggsCurrentBufferSize, emptyResults, next);
b = buffer = new ArrayList<>();
emptyResults = new ArrayList<>();
aggsCurrentBufferSize = 0;
buffer.clear();
emptyResults.clear();
queue.add(task);
tryExecuteNext();
}
buffer.add(result);
b.add(result);
}
}
}
Expand All @@ -404,10 +405,13 @@ private void consume(QuerySearchResult result, Runnable next) {
}

private void releaseBuffer() {
for (QuerySearchResult querySearchResult : buffer) {
querySearchResult.releaseAggs();
var b = buffer;
if (b != null) {
this.buffer = null;
for (QuerySearchResult querySearchResult : b) {
querySearchResult.releaseAggs();
}
}
buffer.clear();
}

private synchronized void onMergeFailure(Exception exc) {
Expand Down Expand Up @@ -449,7 +453,7 @@ private void tryExecuteNext() {
@Override
protected void doRun() {
MergeTask mergeTask = task;
QuerySearchResult[] toConsume = mergeTask.consumeBuffer();
List<QuerySearchResult> toConsume = mergeTask.consumeBuffer();
while (mergeTask != null) {
final MergeResult thisMergeResult = mergeResult;
long estimatedTotalSize = (thisMergeResult != null ? thisMergeResult.estimatedSize : 0) + mergeTask.aggsBufferSize;
Expand Down Expand Up @@ -512,15 +516,7 @@ public void onFailure(Exception exc) {
});
}

private synchronized void releaseAggs() {
if (hasAggs) {
for (QuerySearchResult result : buffer) {
result.releaseAggs();
}
}
}

private static void releaseAggs(QuerySearchResult... toConsume) {
private static void releaseAggs(List<QuerySearchResult> toConsume) {
for (QuerySearchResult result : toConsume) {
result.releaseAggs();
}
Expand All @@ -535,19 +531,19 @@ private record MergeResult(

private static class MergeTask {
private final List<SearchShard> emptyResults;
private QuerySearchResult[] buffer;
private List<QuerySearchResult> buffer;
private final long aggsBufferSize;
private Runnable next;

private MergeTask(QuerySearchResult[] buffer, long aggsBufferSize, List<SearchShard> emptyResults, Runnable next) {
private MergeTask(List<QuerySearchResult> buffer, long aggsBufferSize, List<SearchShard> emptyResults, Runnable next) {
this.buffer = buffer;
this.aggsBufferSize = aggsBufferSize;
this.emptyResults = emptyResults;
this.next = next;
}

public synchronized QuerySearchResult[] consumeBuffer() {
QuerySearchResult[] toRet = buffer;
public synchronized List<QuerySearchResult> consumeBuffer() {
List<QuerySearchResult> toRet = buffer;
buffer = null;
return toRet;
}
Expand All @@ -559,7 +555,7 @@ public synchronized Runnable consumeListener() {
}

public void cancel() {
QuerySearchResult[] buffer = consumeBuffer();
List<QuerySearchResult> buffer = consumeBuffer();
if (buffer != null) {
releaseAggs(buffer);
}
Expand Down

0 comments on commit 4b59fa7

Please sign in to comment.