
Commit

Refactor how SearchHits are processed in ML module (elastic#120258)
Don't accumulate Rows in memory, to save some heap.
iverase authored Jan 17, 2025
1 parent 9782179 commit 4a2abab
Showing 6 changed files with 114 additions and 103 deletions.
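In outline, the change narrows what each batch keeps alive: next() now returns the raw SearchHit[] and each Row is built on demand, instead of a Row being allocated for every hit in the batch up front. A minimal sketch of the before/after call pattern, using names from the diffs below (process stands in for hypothetical caller logic):

    // Before (sketch): a full batch of Row objects was materialized at once.
    // Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();

    // After (sketch): the batch is just the hits; one short-lived Row at a time.
    while (dataExtractor.hasNext()) {
        Optional<SearchHit[]> hits = dataExtractor.next();
        if (hits.isPresent()) {
            for (SearchHit hit : hits.get()) {
                process(dataExtractor.createRow(hit)); // Row becomes garbage after this call
            }
        }
    }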
@@ -77,10 +77,10 @@ public TransportPreviewDataFrameAnalyticsAction(
         this.clusterService = clusterService;
     }
 
-    private static Map<String, Object> mergeRow(DataFrameDataExtractor.Row row, List<String> fieldNames) {
-        return row.getValues() == null
+    private static Map<String, Object> mergeRow(String[] row, List<String> fieldNames) {
+        return row == null
             ? Collections.emptyMap()
-            : IntStream.range(0, row.getValues().length).boxed().collect(Collectors.toMap(fieldNames::get, i -> row.getValues()[i]));
+            : IntStream.range(0, row.length).boxed().collect(Collectors.toMap(fieldNames::get, i -> row[i]));
     }
 
     @Override
@@ -121,7 +121,7 @@ void preview(Task task, DataFrameAnalyticsConfig config, ActionListener<Response
         ).newExtractor(false);
         extractor.preview(delegate.delegateFailureAndWrap((l, rows) -> {
             List<String> fieldNames = extractor.getFieldNames();
-            l.onResponse(new Response(rows.stream().map((r) -> mergeRow(r, fieldNames)).collect(Collectors.toList())));
+            l.onResponse(new Response(rows.stream().map(r -> mergeRow(r, fieldNames)).collect(Collectors.toList())));
         }));
     }));
 }
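For illustration, the reworked mergeRow simply zips field names with the extracted values, and the null check now applies to the value array itself rather than to Row.getValues(). A hypothetical example (inputs invented):

    List<String> fieldNames = List.of("feature_a", "feature_b");
    String[] row = { "1.0", "2.0" };
    Map<String, Object> merged = mergeRow(row, fieldNames); // {feature_a=1.0, feature_b=2.0}
    // A null row (no values extracted for the hit) still yields Collections.emptyMap().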
@@ -19,7 +19,6 @@
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHit;
-import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.fetch.StoredFieldsContext;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.xpack.core.ClientHelper;
@@ -107,14 +106,14 @@ public void cancel() {
         isCancelled = true;
     }
 
-    public Optional<List<Row>> next() throws IOException {
+    public Optional<SearchHit[]> next() throws IOException {
         if (hasNext() == false) {
             throw new NoSuchElementException();
         }
 
-        Optional<List<Row>> hits = Optional.ofNullable(nextSearch());
-        if (hits.isPresent() && hits.get().isEmpty() == false) {
-            lastSortKey = hits.get().get(hits.get().size() - 1).getSortKey();
+        Optional<SearchHit[]> hits = Optional.ofNullable(nextSearch());
+        if (hits.isPresent() && hits.get().length > 0) {
+            lastSortKey = (long) hits.get()[hits.get().length - 1].getSortValues()[0];
         } else {
             hasNext = false;
         }
@@ -126,7 +125,7 @@ public Optional<List<Row>> next() throws IOException {
      * Does no sorting of the results.
      * @param listener To alert with the extracted rows
      */
-    public void preview(ActionListener<List<Row>> listener) {
+    public void preview(ActionListener<List<String[]>> listener) {
 
         SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client)
             // This ensures the search throws if there are failures and the scroll context gets cleared automatically
@@ -155,22 +154,24 @@ public void preview(ActionListener<List<Row>> listener) {
                 return;
             }
 
-            List<Row> rows = new ArrayList<>(searchResponse.getHits().getHits().length);
+            List<String[]> rows = new ArrayList<>(searchResponse.getHits().getHits().length);
             for (SearchHit hit : searchResponse.getHits().getHits()) {
-                var unpooled = hit.asUnpooled();
-                String[] extractedValues = extractValues(unpooled);
-                rows.add(extractedValues == null ? new Row(null, unpooled, true) : new Row(extractedValues, unpooled, false));
+                String[] extractedValues = extractValues(hit);
+                rows.add(extractedValues);
             }
             delegate.onResponse(rows);
         })
        );
    }
 
-    protected List<Row> nextSearch() throws IOException {
+    protected SearchHit[] nextSearch() throws IOException {
+        if (isCancelled) {
+            return null;
+        }
         return tryRequestWithSearchResponse(() -> executeSearchRequest(buildSearchRequest()));
     }
 
-    private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request) throws IOException {
+    private SearchHit[] tryRequestWithSearchResponse(Supplier<SearchResponse> request) throws IOException {
         try {
 
             // We've set allow_partial_search_results to false which means if something
@@ -179,7 +180,7 @@ private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request)
         try {
             LOGGER.trace(() -> "[" + context.jobId + "] Search response was obtained");
 
-            List<Row> rows = processSearchResponse(searchResponse);
+            SearchHit[] rows = processSearchResponse(searchResponse);
 
             // Request was successfully executed and processed so we can restore the flag to retry if a future failure occurs
             hasPreviousSearchFailed = false;
@@ -246,22 +247,12 @@ private void setFetchSource(SearchRequestBuilder searchRequestBuilder) {
         }
     }
 
-    private List<Row> processSearchResponse(SearchResponse searchResponse) {
-        if (searchResponse.getHits().getHits().length == 0) {
+    private SearchHit[] processSearchResponse(SearchResponse searchResponse) {
+        if (isCancelled || searchResponse.getHits().getHits().length == 0) {
             hasNext = false;
             return null;
         }
 
-        SearchHits hits = searchResponse.getHits();
-        List<Row> rows = new ArrayList<>(hits.getHits().length);
-        for (SearchHit hit : hits) {
-            if (isCancelled) {
-                hasNext = false;
-                break;
-            }
-            rows.add(createRow(hit));
-        }
-        return rows;
+        return searchResponse.getHits().asUnpooled().getHits();
     }
 
     private String extractNonProcessedValues(SearchHit hit, String organicFeature) {
@@ -317,14 +308,13 @@ private String[] extractProcessedValue(ProcessedField processedField, SearchHit
         return extractedValue;
     }
 
-    private Row createRow(SearchHit hit) {
-        var unpooled = hit.asUnpooled();
-        String[] extractedValues = extractValues(unpooled);
+    public Row createRow(SearchHit hit) {
+        String[] extractedValues = extractValues(hit);
         if (extractedValues == null) {
-            return new Row(null, unpooled, true);
+            return new Row(null, hit, true);
         }
         boolean isTraining = trainTestSplitter.get().isTraining(extractedValues);
-        Row row = new Row(extractedValues, unpooled, isTraining);
+        Row row = new Row(extractedValues, hit, isTraining);
         LOGGER.trace(
             () -> format(
                 "[%s] Extracted row: sort key = [%s], is_training = [%s], values = %s",
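Two mechanical details of the new extractor contract: hits are detached from the pooled SearchResponse once per batch via getHits().asUnpooled().getHits() rather than hit-by-hit, and the continuation point is read straight off the last raw hit instead of from Row.getSortKey(). A sketch of both, assuming (as the extractor's paging does) a single numeric sort field:

    // Detach the whole batch so the hits stay usable after the response is released:
    SearchHit[] hits = searchResponse.getHits().asUnpooled().getHits();

    // Resume point for the next search request; sort values hold one long key here:
    if (hits.length > 0) {
        long lastSortKey = (long) hits[hits.length - 1].getSortValues()[0];
    }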
@@ -16,6 +16,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
@@ -256,9 +257,14 @@ private static void writeDataRows(
         long rowsProcessed = 0;
 
         while (dataExtractor.hasNext()) {
-            Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
+            Optional<SearchHit[]> rows = dataExtractor.next();
             if (rows.isPresent()) {
-                for (DataFrameDataExtractor.Row row : rows.get()) {
+                for (SearchHit searchHit : rows.get()) {
+                    if (dataExtractor.isCancelled()) {
+                        break;
+                    }
+                    rowsProcessed++;
+                    DataFrameDataExtractor.Row row = dataExtractor.createRow(searchHit);
                     if (row.shouldSkip()) {
                         dataCountsTracker.incrementSkippedDocsCount();
                     } else {
@@ -271,7 +277,6 @@ private static void writeDataRows(
                     }
                 }
             }
-            rowsProcessed += rows.get().size();
             progressTracker.updateLoadingDataProgress(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows));
         }
     }
@@ -14,6 +14,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
@@ -22,11 +23,9 @@
 import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
 
 import java.io.IOException;
-import java.util.Collections;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.LinkedList;
-import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
@@ -97,6 +96,9 @@ private void addResultAndJoinIfEndOfBatch(RowResults rowResults) {
     private void joinCurrentResults() {
         try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) {
             while (currentResults.isEmpty() == false) {
+                if (dataExtractor.isCancelled()) {
+                    break;
+                }
                 RowResults result = currentResults.pop();
                 DataFrameDataExtractor.Row row = dataFrameRowsIterator.next();
                 checkChecksumsMatch(row, result);
@@ -164,20 +166,20 @@ private void consumeDataExtractor() throws IOException {
 
     private class ResultMatchingDataFrameRows implements Iterator<DataFrameDataExtractor.Row> {
 
-        private List<DataFrameDataExtractor.Row> currentDataFrameRows = Collections.emptyList();
+        private SearchHit[] currentDataFrameRows = SearchHits.EMPTY;
         private int currentDataFrameRowsIndex;
 
         @Override
         public boolean hasNext() {
-            return dataExtractor.hasNext() || currentDataFrameRowsIndex < currentDataFrameRows.size();
+            return dataExtractor.hasNext() || currentDataFrameRowsIndex < currentDataFrameRows.length;
         }
 
         @Override
         public DataFrameDataExtractor.Row next() {
             DataFrameDataExtractor.Row row = null;
             while (hasNoMatch(row) && hasNext()) {
                 advanceToNextBatchIfNecessary();
-                row = currentDataFrameRows.get(currentDataFrameRowsIndex++);
+                row = dataExtractor.createRow(currentDataFrameRows[currentDataFrameRowsIndex++]);
             }
 
             if (hasNoMatch(row)) {
@@ -191,13 +193,13 @@ private static boolean hasNoMatch(DataFrameDataExtractor.Row row) {
         }
 
         private void advanceToNextBatchIfNecessary() {
-            if (currentDataFrameRowsIndex >= currentDataFrameRows.size()) {
-                currentDataFrameRows = getNextDataRowsBatch().orElse(Collections.emptyList());
+            if (currentDataFrameRowsIndex >= currentDataFrameRows.length) {
+                currentDataFrameRows = getNextDataRowsBatch().orElse(SearchHits.EMPTY);
                 currentDataFrameRowsIndex = 0;
             }
         }
 
-        private Optional<List<DataFrameDataExtractor.Row>> getNextDataRowsBatch() {
+        private Optional<SearchHit[]> getNextDataRowsBatch() {
             try {
                 return dataExtractor.next();
             } catch (IOException e) {
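ResultMatchingDataFrameRows is now a lazy flattening iterator: it keeps the current SearchHit[] batch plus an index, and only materializes a Row inside next(). The same pattern in miniature, as a self-contained sketch (generic names invented; the real class additionally skips rows that cannot match a result):

    import java.util.Iterator;
    import java.util.NoSuchElementException;
    import java.util.function.Function;
    import java.util.function.Supplier;

    // Batch-flattening iterator: fetch raw batches, materialize one element per next().
    final class LazyBatchIterator<H, R> implements Iterator<R> {
        private final Supplier<H[]> nextBatch;    // e.g. () -> dataExtractor.next().orElse(SearchHits.EMPTY)
        private final Function<H, R> materialize; // e.g. dataExtractor::createRow
        private H[] batch;
        private int index;

        LazyBatchIterator(Supplier<H[]> nextBatch, Function<H, R> materialize, H[] emptyBatch) {
            this.nextBatch = nextBatch;
            this.materialize = materialize;
            this.batch = emptyBatch;
        }

        @Override
        public boolean hasNext() {
            while (index >= batch.length) {
                H[] next = nextBatch.get();
                if (next.length == 0) {
                    return false; // upstream exhausted
                }
                batch = next;
                index = 0;
            }
            return true;
        }

        @Override
        public R next() {
            if (hasNext() == false) {
                throw new NoSuchElementException();
            }
            return materialize.apply(batch[index++]); // one R allocated at a time
        }
    }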
(The diffs for the remaining two changed files were not loaded in this view.)
