diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 1497bd1022bf..7c0a18f3db94 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -433,8 +433,10 @@ private void runInternal(final QueryListener queryListener, final Closer closer) } } if (queryKernel != null && queryKernel.isSuccess()) { - // If successful, encourage the tasks to exit successfully. - postFinishToAllTasks(); + // If successful, encourage workers to exit successfully. + // Only send this command to participating workers. For task-based queries, this is all tasks, since tasks + // are launched only when needed. For Dart, this is any servers that were actually assigned work items. + postFinishToWorkers(queryKernel.getAllParticipatingWorkers()); workerManager.stop(false); } else { // If not successful, cancel running tasks. @@ -1462,15 +1464,15 @@ private List findIntervalsToDrop(final Set publishedSegme return IntervalUtils.difference(replaceIntervals, publishIntervals); } - private CounterSnapshotsTree getCountersFromAllTasks() + private CounterSnapshotsTree fetchCountersFromWorkers(final IntSet workers) { final CounterSnapshotsTree retVal = new CounterSnapshotsTree(); final List taskList = getWorkerIds(); final List> futures = new ArrayList<>(); - for (String taskId : taskList) { - futures.add(netClient.getCounters(taskId)); + for (int workerNumber : workers) { + futures.add(netClient.getCounters(taskList.get(workerNumber))); } final List snapshotsTrees = @@ -1483,14 +1485,14 @@ private CounterSnapshotsTree getCountersFromAllTasks() return retVal; } - private void postFinishToAllTasks() + private void postFinishToWorkers(final IntSet workers) { final List taskList = getWorkerIds(); final List> futures = new ArrayList<>(); - for (String taskId : taskList) { - futures.add(netClient.postFinish(taskId)); + for (int workerNumber : workers) { + futures.add(netClient.postFinish(taskList.get(workerNumber))); } FutureUtils.getUnchecked(MSQFutureUtils.allAsList(futures, true), true); @@ -1505,7 +1507,7 @@ private CounterSnapshotsTree makeCountersSnapshotForLiveReports() private CounterSnapshotsTree getFinalCountersSnapshot(@Nullable final ControllerQueryKernel queryKernel) { if (queryKernel != null && queryKernel.isSuccess()) { - return getCountersFromAllTasks(); + return fetchCountersFromWorkers(queryKernel.getAllParticipatingWorkers()); } else { return makeCountersSnapshotForLiveReports(); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java index 62a133269093..90c9b496721d 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java @@ -27,6 +27,7 @@ import it.unimi.dsi.fastutil.ints.Int2IntMap; import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; import it.unimi.dsi.fastutil.ints.IntSet; import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.java.util.common.IAE; @@ -643,6 +644,20 @@ public void failStage(final StageId stageId) doWithStageTracker(stageId, ControllerStageTracker::fail); } + /** + * Returns the set of all worker numbers that have participated in work done so far by this query. + */ + public IntSet getAllParticipatingWorkers() + { + final IntSet retVal = new IntAVLTreeSet(); + + for (final ControllerStageTracker tracker : stageTrackers.values()) { + retVal.addAll(tracker.getWorkerInputs().workers()); + } + + return retVal; + } + /** * Fetches and returns the stage kernel corresponding to the provided stage id, else throws {@link IAE} */ diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java index 0a62ba24b639..338a35e0d244 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java @@ -63,7 +63,6 @@ import java.util.Map; import java.util.Set; import java.util.TreeMap; -import java.util.stream.IntStream; /** * Controller-side state machine for each stage. Used by {@link ControllerQueryKernel} to form the overall state @@ -137,7 +136,7 @@ private ControllerStageTracker( this.workerInputs = workerInputs; this.maxRetainedPartitionSketchBytes = maxRetainedPartitionSketchBytes; - initializeWorkerState(workerCount); + initializeWorkerState(workerInputs.workers()); if (stageDef.mustGatherResultKeyStatistics()) { this.completeKeyStatisticsInformation = @@ -149,14 +148,13 @@ private ControllerStageTracker( } /** - * Initialize stage for each worker to {@link ControllerWorkerStagePhase#NEW} - * - * @param workerCount + * Initialize stage for each worker to {@link ControllerWorkerStagePhase#NEW}. */ - private void initializeWorkerState(int workerCount) + private void initializeWorkerState(IntSet workers) { - IntStream.range(0, workerCount) - .forEach(wokerNumber -> workerToPhase.put(wokerNumber, ControllerWorkerStagePhase.NEW)); + for (int workerNumber : workers) { + workerToPhase.put(workerNumber, ControllerWorkerStagePhase.NEW); + } } /**