Skip to content

Commit

Permalink
Enable node-level reduction by default (#119621) (#119688)
Browse files Browse the repository at this point in the history
This change enables node-level reduction by default in ES|QL. However, 
if the coordinator node and the target data node are the same,
node-level reduction is disabled to avoid unnecessary overhead.
  • Loading branch information
dnhatn authored Jan 7, 2025
1 parent 4adebcf commit 94de73a
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 48 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/119621.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 119621
summary: Enable node-level reduction by default
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ static TransportVersion def(int id) {
public static final TransportVersion NODE_SHUTDOWN_EPHEMERAL_ID_ADDED = def(8_815_00_0);
public static final TransportVersion ESQL_CCS_TELEMETRY_STATS = def(8_816_00_0);
public static final TransportVersion TEXT_EMBEDDING_QUERY_VECTOR_BUILDER_INFER_MODEL_ID = def(8_817_00_0);
public static final TransportVersion ESQL_ENABLE_NODE_LEVEL_REDUCTION = def(8_818_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,13 @@ private ActionFuture<EsqlQueryResponse> startEsql(String query) {
// Report the status after every action
.put("status_interval", "0ms");

if (nodeLevelReduction == false) {
// explicitly set the default (false) or don't
if (nodeLevelReduction) {
// explicitly set the default (true) or don't
if (randomBoolean()) {
settingsBuilder.put("node_level_reduction", nodeLevelReduction);
settingsBuilder.put("node_level_reduction", true);
}
} else {
settingsBuilder.put("node_level_reduction", nodeLevelReduction);
settingsBuilder.put("node_level_reduction", false);
}

var pragmas = new QueryPragmas(settingsBuilder.build());
Expand All @@ -273,14 +273,7 @@ private void cancelTask(TaskId taskId) {
private List<TaskInfo> getTasksStarting() throws Exception {
List<TaskInfo> foundTasks = new ArrayList<>();
assertBusy(() -> {
List<TaskInfo> tasks = client().admin()
.cluster()
.prepareListTasks()
.setActions(DriverTaskRunner.ACTION_NAME)
.setDetailed(true)
.get()
.getTasks();
assertThat(tasks, hasSize(equalTo(3)));
List<TaskInfo> tasks = getDriverTasks();
for (TaskInfo task : tasks) {
assertThat(task.action(), equalTo(DriverTaskRunner.ACTION_NAME));
DriverStatus status = (DriverStatus) task.status();
Expand All @@ -305,14 +298,7 @@ private List<TaskInfo> getTasksStarting() throws Exception {
private List<TaskInfo> getTasksRunning() throws Exception {
List<TaskInfo> foundTasks = new ArrayList<>();
assertBusy(() -> {
List<TaskInfo> tasks = client().admin()
.cluster()
.prepareListTasks()
.setActions(DriverTaskRunner.ACTION_NAME)
.setDetailed(true)
.get()
.getTasks();
assertThat(tasks, hasSize(equalTo(3)));
List<TaskInfo> tasks = getDriverTasks();
for (TaskInfo task : tasks) {
assertThat(task.action(), equalTo(DriverTaskRunner.ACTION_NAME));
DriverStatus status = (DriverStatus) task.status();
Expand All @@ -328,6 +314,37 @@ private List<TaskInfo> getTasksRunning() throws Exception {
return foundTasks;
}

/**
 * Fetches tasks until all three driver tasks (read, merge, and reduce) have started,
 * then returns them. Uses {@code assertBusy} to retry until the cluster reports
 * exactly three {@code DriverTaskRunner} tasks.
 */
private List<TaskInfo> getDriverTasks() throws Exception {
    List<TaskInfo> foundTasks = new ArrayList<>();
    assertBusy(() -> {
        // List all driver tasks cluster-wide with detailed descriptions so we can
        // match each task against the expected plan description strings.
        List<TaskInfo> tasks = client().admin()
            .cluster()
            .prepareListTasks()
            .setActions(DriverTaskRunner.ACTION_NAME)
            .setDetailed(true)
            .get()
            .getTasks();
        // Expect exactly one read, one merge, and one reduce driver task.
        assertThat(tasks, hasSize(equalTo(3)));
        // READ_DESCRIPTION / MERGE_DESCRIPTION / REDUCE_DESCRIPTION are defined
        // elsewhere in this test class — presumably the expected driver plan
        // descriptions (TODO confirm against the enclosing class).
        List<TaskInfo> readTasks = tasks.stream().filter(t -> t.description().equals(READ_DESCRIPTION)).toList();
        List<TaskInfo> mergeTasks = tasks.stream().filter(t -> t.description().equals(MERGE_DESCRIPTION)).toList();
        assertThat(readTasks, hasSize(1));
        assertThat(mergeTasks, hasSize(1));
        // node-level reduction is disabled when the target data node is also the coordinator,
        // so in that case the reduce driver runs a bare exchange (no reduction operators) and
        // the expected description must be swapped. NOTE(review): this mutates shared state
        // (REDUCE_DESCRIPTION) from inside assertBusy's retry loop — confirm the field is
        // reset between tests if the class reuses it.
        if (readTasks.get(0).node().equals(mergeTasks.get(0).node())) {
            REDUCE_DESCRIPTION = """
                \\_ExchangeSourceOperator[]
                \\_ExchangeSinkOperator""";
        }
        List<TaskInfo> reduceTasks = tasks.stream().filter(t -> t.description().equals(REDUCE_DESCRIPTION)).toList();
        assertThat(reduceTasks, hasSize(1));
        foundTasks.addAll(tasks);
    });
    return foundTasks;
}

private void assertCancelled(ActionFuture<EsqlQueryResponse> response) throws Exception {
Exception e = expectThrows(Exception.class, response);
Throwable cancelException = ExceptionsHelper.unwrap(e, TaskCancelledException.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,19 +358,22 @@ private void startComputeOnDataNodes(
exchangeSource.addRemoteSink(remoteSink, true, queryPragmas.concurrentExchangeClients(), ActionListener.noop());
ActionListener<ComputeResponse> computeResponseListener = computeListener.acquireCompute(clusterAlias);
var dataNodeListener = ActionListener.runBefore(computeResponseListener, () -> l.onResponse(null));
final boolean sameNode = transportService.getLocalNode().getId().equals(node.connection.getNode().getId());
var dataNodeRequest = new DataNodeRequest(
childSessionId,
configuration,
clusterAlias,
node.shardIds,
node.aliasFilters,
dataNodePlan,
originalIndices.indices(),
originalIndices.indicesOptions(),
sameNode == false && queryPragmas.nodeLevelReduction()
);
transportService.sendChildRequest(
node.connection,
DATA_ACTION_NAME,
new DataNodeRequest(
childSessionId,
configuration,
clusterAlias,
node.shardIds,
node.aliasFilters,
dataNodePlan,
originalIndices.indices(),
originalIndices.indicesOptions()
),
dataNodeRequest,
parentTask,
TransportRequestOptions.EMPTY,
new ActionListenerResponseHandler<>(dataNodeListener, ComputeResponse::new, esqlExecutor)
Expand Down Expand Up @@ -803,7 +806,7 @@ public void messageReceived(DataNodeRequest request, TransportChannel channel, T
final ActionListener<ComputeResponse> listener = new ChannelActionListener<>(channel);
final PhysicalPlan reductionPlan;
if (request.plan() instanceof ExchangeSinkExec plan) {
reductionPlan = reductionPlan(plan, request.pragmas().nodeLevelReduction());
reductionPlan = reductionPlan(plan, request.runNodeLevelReduction());
} else {
listener.onFailure(new IllegalStateException("expected exchange sink for a remote compute; got " + request.plan()));
return;
Expand All @@ -817,7 +820,8 @@ public void messageReceived(DataNodeRequest request, TransportChannel channel, T
request.aliasFilters(),
request.plan(),
request.indices(),
request.indicesOptions()
request.indicesOptions(),
request.runNodeLevelReduction()
);
try (var computeListener = ComputeListener.create(transportService, (CancellableTask) task, listener)) {
runComputeOnDataNode((CancellableTask) task, sessionId, reductionPlan, request, computeListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ final class DataNodeRequest extends TransportRequest implements IndicesRequest.R
private List<ShardId> shardIds;
private String[] indices;
private final IndicesOptions indicesOptions;
private final boolean runNodeLevelReduction;

DataNodeRequest(
String sessionId,
Expand All @@ -63,7 +64,8 @@ final class DataNodeRequest extends TransportRequest implements IndicesRequest.R
Map<Index, AliasFilter> aliasFilters,
PhysicalPlan plan,
String[] indices,
IndicesOptions indicesOptions
IndicesOptions indicesOptions,
boolean runNodeLevelReduction
) {
this.sessionId = sessionId;
this.configuration = configuration;
Expand All @@ -73,6 +75,7 @@ final class DataNodeRequest extends TransportRequest implements IndicesRequest.R
this.plan = plan;
this.indices = indices;
this.indicesOptions = indicesOptions;
this.runNodeLevelReduction = runNodeLevelReduction;
}

DataNodeRequest(StreamInput in) throws IOException {
Expand All @@ -97,6 +100,11 @@ final class DataNodeRequest extends TransportRequest implements IndicesRequest.R
this.indices = shardIds.stream().map(ShardId::getIndexName).distinct().toArray(String[]::new);
this.indicesOptions = IndicesOptions.strictSingleIndexNoExpandForbidClosed();
}
if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_ENABLE_NODE_LEVEL_REDUCTION)) {
this.runNodeLevelReduction = in.readBoolean();
} else {
this.runNodeLevelReduction = false;
}
}

@Override
Expand All @@ -114,6 +122,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeStringArray(indices);
indicesOptions.writeIndicesOptions(out);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_ENABLE_NODE_LEVEL_REDUCTION)) {
out.writeBoolean(runNodeLevelReduction);
}
}

@Override
Expand Down Expand Up @@ -186,6 +197,10 @@ PhysicalPlan plan() {
return plan;
}

// Whether the data node should run a node-level reduction before sending results
// back to the coordinator; set by the coordinator when building this request.
boolean runNodeLevelReduction() {
    return runNodeLevelReduction;
}

@Override
public String getDescription() {
return "shards=" + shardIds + " plan=" + plan;
Expand All @@ -209,11 +224,22 @@ public boolean equals(Object o) {
&& plan.equals(request.plan)
&& getParentTask().equals(request.getParentTask())
&& Arrays.equals(indices, request.indices)
&& indicesOptions.equals(request.indicesOptions);
&& indicesOptions.equals(request.indicesOptions)
&& runNodeLevelReduction == request.runNodeLevelReduction;
}

@Override
public int hashCode() {
return Objects.hash(sessionId, configuration, clusterAlias, shardIds, aliasFilters, plan, Arrays.hashCode(indices), indicesOptions);
return Objects.hash(
sessionId,
configuration,
clusterAlias,
shardIds,
aliasFilters,
plan,
Arrays.hashCode(indices),
indicesOptions,
runNodeLevelReduction
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public final class QueryPragmas implements Writeable {

public static final Setting<Integer> MAX_CONCURRENT_SHARDS_PER_NODE = Setting.intSetting("max_concurrent_shards_per_node", 10, 1, 100);

public static final Setting<Boolean> NODE_LEVEL_REDUCTION = Setting.boolSetting("node_level_reduction", false);
public static final Setting<Boolean> NODE_LEVEL_REDUCTION = Setting.boolSetting("node_level_reduction", true);

public static final QueryPragmas EMPTY = new QueryPragmas(Settings.EMPTY);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,16 @@ protected DataNodeRequest createTestInstance() {
aliasFilters,
physicalPlan,
generateRandomStringArray(10, 10, false, false),
IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean())
IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()),
randomBoolean()
);
request.setParentTask(randomAlphaOfLength(10), randomNonNegativeLong());
return request;
}

@Override
protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException {
return switch (between(0, 8)) {
return switch (between(0, 9)) {
case 0 -> {
var request = new DataNodeRequest(
randomAlphaOfLength(20),
Expand All @@ -110,7 +111,8 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException
in.aliasFilters(),
in.plan(),
in.indices(),
in.indicesOptions()
in.indicesOptions(),
in.runNodeLevelReduction()
);
request.setParentTask(in.getParentTask());
yield request;
Expand All @@ -124,7 +126,8 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException
in.aliasFilters(),
in.plan(),
in.indices(),
in.indicesOptions()
in.indicesOptions(),
in.runNodeLevelReduction()
);
request.setParentTask(in.getParentTask());
yield request;
Expand All @@ -139,7 +142,8 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException
in.aliasFilters(),
in.plan(),
in.indices(),
in.indicesOptions()
in.indicesOptions(),
in.runNodeLevelReduction()
);
request.setParentTask(in.getParentTask());
yield request;
Expand All @@ -166,7 +170,8 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException
in.aliasFilters(),
mapAndMaybeOptimize(parse(newQuery)),
in.indices(),
in.indicesOptions()
in.indicesOptions(),
in.runNodeLevelReduction()
);
request.setParentTask(in.getParentTask());
yield request;
Expand All @@ -186,7 +191,8 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException
aliasFilters,
in.plan(),
in.indices(),
in.indicesOptions()
in.indicesOptions(),
in.runNodeLevelReduction()
);
request.setParentTask(request.getParentTask());
yield request;
Expand All @@ -200,7 +206,8 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException
in.aliasFilters(),
in.plan(),
in.indices(),
in.indicesOptions()
in.indicesOptions(),
in.runNodeLevelReduction()
);
request.setParentTask(
randomValueOtherThan(request.getParentTask().getNodeId(), () -> randomAlphaOfLength(10)),
Expand All @@ -218,7 +225,8 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException
in.aliasFilters(),
in.plan(),
in.indices(),
in.indicesOptions()
in.indicesOptions(),
in.runNodeLevelReduction()
);
request.setParentTask(request.getParentTask());
yield request;
Expand All @@ -233,7 +241,8 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException
in.aliasFilters(),
in.plan(),
indices,
in.indicesOptions()
in.indicesOptions(),
in.runNodeLevelReduction()
);
request.setParentTask(request.getParentTask());
yield request;
Expand All @@ -251,7 +260,23 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException
in.aliasFilters(),
in.plan(),
in.indices(),
indicesOptions
indicesOptions,
in.runNodeLevelReduction()
);
request.setParentTask(request.getParentTask());
yield request;
}
case 9 -> {
var request = new DataNodeRequest(
in.sessionId(),
in.configuration(),
in.clusterAlias(),
in.shardIds(),
in.aliasFilters(),
in.plan(),
in.indices(),
in.indicesOptions(),
in.runNodeLevelReduction() == false
);
request.setParentTask(request.getParentTask());
yield request;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ public void testNoIndexPlaceholder() {
Collections.emptyMap(),
null,
generateRandomStringArray(10, 10, false, false),
IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean())
IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()),
randomBoolean()
);

assertThat(request.shardIds(), equalTo(shardIds));
Expand Down

0 comments on commit 94de73a

Please sign in to comment.