diff --git a/extensions-core/multi-stage-query/pom.xml b/extensions-core/multi-stage-query/pom.xml index 8939018661ce..9e637acff009 100644 --- a/extensions-core/multi-stage-query/pom.xml +++ b/extensions-core/multi-stage-query/pom.xml @@ -186,6 +186,11 @@ datasketches-memory provided + + it.unimi.dsi + fastutil + provided + it.unimi.dsi fastutil-core @@ -288,6 +293,13 @@ test-jar test + + org.apache.druid + druid-indexing-service + ${project.parent.version} + test-jar + test + org.apache.druid druid-sql diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java index 5e23a42b2fa1..f04286dd7c42 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java @@ -19,64 +19,42 @@ package org.apache.druid.msq.exec; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import org.apache.druid.indexer.TaskStatus; import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.msq.counters.CounterSnapshots; import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.indexing.client.ControllerChatHandler; import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; import javax.annotation.Nullable; import java.util.List; /** - * Interface for the controller of a multi-stage query. + * Interface for the controller of a multi-stage query. Each Controller is specific to a particular query. + * + * @see ControllerImpl the production implementation */ public interface Controller { - /** - * POJO for capturing the status of a controller task that is currently running. - */ - class RunningControllerStatus - { - private final String id; - - @JsonCreator - public RunningControllerStatus(String id) - { - this.id = id; - } - - @JsonProperty("id") - public String getId() - { - return id; - } - } - /** * Unique task/query ID for the batch query run by this controller. + * + * Controller IDs must be globally unique. For tasks, this is the task ID from {@link MSQControllerTask#getId()}. */ - String id(); - - /** - * The task which this controller runs. - */ - MSQControllerTask task(); + String queryId(); /** * Runs the controller logic in the current thread. Surrounding classes provide the execution thread. */ - TaskStatus run() throws Exception; + void run(QueryListener listener) throws Exception; /** - * Terminate the query DAG upon a cancellation request. + * Terminate the controller upon a cancellation request. Causes a concurrently-running {@link #run} method in + * a separate thread to cancel all outstanding work and exit. */ - void stopGracefully(); + void stop(); // Worker-to-controller messages @@ -84,13 +62,29 @@ public String getId() * Accepts a {@link PartialKeyStatisticsInformation} and updates the controller key statistics information. If all key * statistics have been gathered, enqueues the task with the {@link WorkerSketchFetcher} to generate partition boundaries. * This is intended to be called by the {@link ControllerChatHandler}.
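The reshaped interface above replaces id()/task()/run()/stopGracefully() with queryId(), run(QueryListener), and stop(): the controller no longer returns a TaskStatus, and results and the final report flow through the listener instead. Below is a minimal sketch of how a caller might drive it, using only what this diff shows. QueryListener appears to live in the same org.apache.druid.msq.exec package, since Controller.java needs no new import for it; the driver class and executor wiring are illustrative, not part of this change.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.druid.msq.exec.Controller;
import org.apache.druid.msq.exec.QueryListener;

// Hypothetical wiring for the reshaped Controller lifecycle: run(QueryListener)
// blocks until the query finishes, and stop() cancels a concurrent run().
class ControllerDriver
{
  private final ExecutorService exec = Executors.newSingleThreadExecutor();

  void start(final Controller controller, final QueryListener listener)
  {
    exec.submit(() -> {
      try {
        // Blocks until the query completes. Results and the final report are
        // delivered through the listener rather than a returned TaskStatus.
        controller.run(listener);
      }
      catch (Exception e) {
        // Query failures are surfaced through the listener's final report;
        // an exception escaping run() is unexpected and should be logged.
      }
    });
  }

  void cancel(final Controller controller)
  {
    // Safe to call from any thread: asks the concurrent run() to cancel
    // outstanding work and exit, per the javadoc above.
    controller.stop();
    exec.shutdownNow();
  }
}
```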
+ * + * @see ControllerClient#postPartialKeyStatistics(StageId, int, PartialKeyStatisticsInformation) + */ + void updatePartialKeyStatisticsInformation( + int stageNumber, + int workerNumber, + Object partialKeyStatisticsInformationObject + ); + + /** + * Sent by workers when they finish reading their input, in cases where they would not otherwise be calling + * {@link #updatePartialKeyStatisticsInformation(int, int, Object)}. + * + * @see ControllerClient#postDoneReadingInput(StageId, int) */ - void updatePartialKeyStatisticsInformation(int stageNumber, int workerNumber, Object partialKeyStatisticsInformationObject); + void doneReadingInput(int stageNumber, int workerNumber); /** * System error reported by a subtask. Note that the errors are organized by * taskId, not by query/stage/worker, because system errors are associated * with a task rather than a specific query/stage/worker execution context. + * + * @see ControllerClient#postWorkerError(String, MSQErrorReport) */ void workerError(MSQErrorReport errorReport); @@ -98,16 +92,22 @@ public String getId() * System warning reported by a subtask. Indicates that the worker has encountered a non-lethal error. Worker should * continue its execution in such a case. If the worker wants to report an error and stop its execution, * please use {@link Controller#workerError} + * + * @see ControllerClient#postWorkerWarning(List) */ void workerWarning(List errorReports); /** * Periodic update of {@link CounterSnapshots} from subtasks. + * + * @see ControllerClient#postCounters(String, CounterSnapshotsTree) */ void updateCounters(String taskId, CounterSnapshotsTree snapshotsTree); /** * Reports that results are ready for a subtask. + * + * @see ControllerClient#postResultsComplete(StageId, int, Object) */ void resultsComplete( String queryId, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java index afd1ece4dad1..405ff4fb9026 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java @@ -21,6 +21,7 @@ import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; @@ -43,6 +44,21 @@ void postPartialKeyStatistics( PartialKeyStatisticsInformation partialKeyStatisticsInformation ) throws IOException; + /** + * Client side method to tell the controller that a particular stage and worker is done reading its input. + * + * The main purpose of this call is to let the controller know when it can stop running the input stage. This helps + * execution roll smoothly from stage to stage during pipelined execution. For backwards-compatibility reasons + * (this is a newer method, only really useful when pipelining), this call should be skipped if the query is not + * pipelining stages. + * + * Only used when {@link StageDefinition#doesSortDuringShuffle()} and *not* + * {@link StageDefinition#mustGatherResultKeyStatistics()}. 
When the stage gathers result key statistics, workers + * call {@link #postPartialKeyStatistics(StageId, int, PartialKeyStatisticsInformation)} instead, which has the same + * effect of telling the controller that the worker is done reading its input. + */ + void postDoneReadingInput(StageId stageId, int workerNumber) throws IOException; + /** * Client-side method to update the controller with counters for a particular stage and worker. The controller uses * this to compile live reports, track warnings generated etc. diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java index 0aa90688b910..40b114511c28 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java @@ -21,24 +21,44 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Injector; -import org.apache.druid.client.coordinator.CoordinatorClient; -import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.indexing.common.actions.TaskActionClient; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.emitter.service.ServiceEmitter; +import org.apache.druid.msq.indexing.MSQSpec; +import org.apache.druid.msq.input.InputSpecSlicer; +import org.apache.druid.msq.input.table.SegmentsInputSlice; +import org.apache.druid.msq.input.table.TableInputSpec; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; import org.apache.druid.server.DruidNode; /** - * Context used by multi-stage query controllers. - * - * Useful because it allows test fixtures to provide their own implementations. + * Context used by multi-stage query controllers. Useful because it allows test fixtures to provide their own + * implementations. */ public interface ControllerContext { - ServiceEmitter emitter(); + /** + * Configuration for {@link org.apache.druid.msq.kernel.controller.ControllerQueryKernel}. + */ + ControllerQueryKernelConfig queryKernelConfig(MSQSpec querySpec, QueryDefinition queryDef); + /** + * Callback from the controller implementation to "register" the controller. Used in the indexing task implementation + * to set up the task chat web service. + */ + void registerController(Controller controller, Closer closer); + + /** + * JSON-enabled object mapper. + */ ObjectMapper jsonMapper(); + /** + * Emit a metric using a {@link ServiceEmitter}. + */ + void emitMetric(String metric, Number value); + /** * Provides a way for tasks to request injectable objects. Useful because tasks are not able to request injection * at the time of server startup, because the server doesn't know what tasks it will be running. @@ -51,32 +71,33 @@ public interface ControllerContext DruidNode selfNode(); /** - * Provide access to the Coordinator service. + * Provides an {@link InputSpecSlicer} that slices {@link TableInputSpec} into {@link SegmentsInputSlice}. */ - CoordinatorClient coordinatorClient(); + InputSpecSlicer newTableInputSpecSlicer(); /** - * Provide access to segment actions in the Overlord. + * Provide access to segment actions in the Overlord. Only called for ingestion queries, i.e., where + * {@link MSQSpec#getDestination()} is {@link org.apache.druid.msq.indexing.destination.DataSourceMSQDestination}. 
*/ TaskActionClient taskActionClient(); /** * Provides services about workers: starting, canceling, obtaining status. + * + * @param queryId query ID + * @param querySpec query spec + * @param queryKernelConfig config from {@link #queryKernelConfig(MSQSpec, QueryDefinition)} + * @param workerFailureListener listener that receives callbacks when workers fail */ - WorkerManagerClient workerManager(); - - /** - * Callback from the controller implementation to "register" the controller. Used in the indexing task implementation - * to set up the task chat web service. - */ - void registerController(Controller controller, Closer closer); + WorkerManager newWorkerManager( + String queryId, + MSQSpec querySpec, + ControllerQueryKernelConfig queryKernelConfig, + WorkerFailureListener workerFailureListener + ); /** * Client for communicating with workers. */ - WorkerClient taskClientFor(Controller controller); - /** - * Writes controller task report. - */ - void writeReports(String controllerTaskId, TaskReport.ReportMap reports); + WorkerClient newWorkerClient(); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 682e2b484e4e..b10fbe76ecfa 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -25,23 +25,17 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; +import com.google.common.collect.Ordering; import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; -import com.google.common.util.concurrent.SettableFuture; import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.IntArraySet; import it.unimi.dsi.fastutil.ints.IntList; import it.unimi.dsi.fastutil.ints.IntSet; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.client.indexing.ClientCompactionTaskTransformSpec; import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.data.input.StringTuple; @@ -51,7 +45,7 @@ import org.apache.druid.discovery.BrokerClient; import org.apache.druid.error.DruidException; import org.apache.druid.frame.allocation.ArenaMemoryAllocator; -import org.apache.druid.frame.channel.FrameChannelSequence; +import org.apache.druid.frame.channel.ReadableConcatFrameChannel; import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterByPartition; import org.apache.druid.frame.key.ClusterByPartitions; @@ -64,11 +58,9 @@ import org.apache.druid.frame.write.InvalidFieldException; import org.apache.druid.frame.write.InvalidNullByteException; import org.apache.druid.indexer.TaskState; -import org.apache.druid.indexer.TaskStatus; import org.apache.druid.indexer.partitions.DimensionRangePartitionsSpec; import 
org.apache.druid.indexer.partitions.DynamicPartitionsSpec; import org.apache.druid.indexer.partitions.PartitionsSpec; -import org.apache.druid.indexer.report.TaskContextReport; import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.indexing.common.LockGranularity; import org.apache.druid.indexing.common.TaskLock; @@ -76,7 +68,6 @@ import org.apache.druid.indexing.common.actions.LockListAction; import org.apache.druid.indexing.common.actions.LockReleaseAction; import org.apache.druid.indexing.common.actions.MarkSegmentsAsUnusedAction; -import org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction; import org.apache.druid.indexing.common.actions.SegmentAllocateAction; import org.apache.druid.indexing.common.actions.SegmentTransactionalAppendAction; import org.apache.druid.indexing.common.actions.SegmentTransactionalInsertAction; @@ -88,6 +79,7 @@ import org.apache.druid.indexing.common.task.batch.parallel.TombstoneHelper; import org.apache.druid.indexing.overlord.SegmentPublishResult; import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.Either; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; @@ -97,9 +89,6 @@ import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; -import org.apache.druid.java.util.common.guava.Sequences; -import org.apache.druid.java.util.common.guava.Yielder; -import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.counters.CounterSnapshots; @@ -109,13 +98,11 @@ import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.MSQTuningConfig; -import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; import org.apache.druid.msq.indexing.WorkerCount; import org.apache.druid.msq.indexing.client.ControllerChatHandler; import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination; import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; import org.apache.druid.msq.indexing.destination.ExportMSQDestination; -import org.apache.druid.msq.indexing.destination.MSQSelectDestination; import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; import org.apache.druid.msq.indexing.error.CanceledFault; import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault; @@ -129,16 +116,13 @@ import org.apache.druid.msq.indexing.error.MSQErrorReport; import org.apache.druid.msq.indexing.error.MSQException; import org.apache.druid.msq.indexing.error.MSQFault; -import org.apache.druid.msq.indexing.error.MSQFaultUtils; import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher; -import org.apache.druid.msq.indexing.error.MSQWarnings; import org.apache.druid.msq.indexing.error.QueryNotSupportedFault; import org.apache.druid.msq.indexing.error.TooManyBucketsFault; import org.apache.druid.msq.indexing.error.TooManyWarningsFault; import org.apache.druid.msq.indexing.error.UnknownFault; import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault; import org.apache.druid.msq.indexing.processor.SegmentGeneratorFrameProcessorFactory; -import org.apache.druid.msq.indexing.report.MSQResultsReport; import 
org.apache.druid.msq.indexing.report.MSQSegmentReport; import org.apache.druid.msq.indexing.report.MSQStagesReport; import org.apache.druid.msq.indexing.report.MSQStatusReport; @@ -160,9 +144,7 @@ import org.apache.druid.msq.input.stage.StageInputSlice; import org.apache.druid.msq.input.stage.StageInputSpec; import org.apache.druid.msq.input.stage.StageInputSpecSlicer; -import org.apache.druid.msq.input.table.DataSegmentWithLocation; import org.apache.druid.msq.input.table.TableInputSpec; -import org.apache.druid.msq.input.table.TableInputSpecSlicer; import org.apache.druid.msq.kernel.QueryDefinition; import org.apache.druid.msq.kernel.QueryDefinitionBuilder; import org.apache.druid.msq.kernel.StageDefinition; @@ -170,9 +152,9 @@ import org.apache.druid.msq.kernel.StagePartition; import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.kernel.controller.ControllerQueryKernel; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; import org.apache.druid.msq.kernel.controller.ControllerStagePhase; import org.apache.druid.msq.kernel.controller.WorkerInputs; -import org.apache.druid.msq.querykit.DataSegmentTimelineView; import org.apache.druid.msq.querykit.MultiQueryKit; import org.apache.druid.msq.querykit.QueryKit; import org.apache.druid.msq.querykit.QueryKitUtils; @@ -191,8 +173,6 @@ import org.apache.druid.msq.util.MSQFutureUtils; import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.msq.util.PassthroughAggregatorFactory; -import org.apache.druid.msq.util.SqlStatementResourceHelper; -import org.apache.druid.query.DruidMetrics; import org.apache.druid.query.Query; import org.apache.druid.query.QueryContext; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -213,19 +193,17 @@ import org.apache.druid.segment.realtime.appenderator.SegmentIdWithShardSpec; import org.apache.druid.segment.transform.TransformSpec; import org.apache.druid.server.DruidNode; -import org.apache.druid.server.coordination.DruidServerMetadata; -import org.apache.druid.sql.calcite.planner.ColumnMapping; import org.apache.druid.sql.calcite.planner.ColumnMappings; import org.apache.druid.sql.calcite.rel.DruidQuery; import org.apache.druid.sql.http.ResultFormat; import org.apache.druid.storage.ExportStorageProvider; import org.apache.druid.timeline.CompactionState; import org.apache.druid.timeline.DataSegment; -import org.apache.druid.timeline.SegmentTimeline; import org.apache.druid.timeline.partition.DimensionRangeShardSpec; import org.apache.druid.timeline.partition.NumberedPartialShardSpec; import org.apache.druid.timeline.partition.NumberedShardSpec; import org.apache.druid.timeline.partition.ShardSpec; +import org.apache.druid.utils.CloseableUtils; import org.apache.druid.utils.CollectionUtils; import org.joda.time.DateTime; import org.joda.time.Interval; @@ -234,7 +212,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -251,8 +228,6 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -264,8 +239,11 @@ public class ControllerImpl implements Controller { private static final Logger log = 
new Logger(ControllerImpl.class); - private final MSQControllerTask task; + private final String queryId; + private final MSQSpec querySpec; + private final ResultsContext resultsContext; private final ControllerContext context; + private volatile ControllerQueryKernelConfig queryKernelConfig; /** * Queue of "commands" to run on the {@link ControllerQueryKernel}. Various threads insert into the queue @@ -308,88 +286,61 @@ public class ControllerImpl implements Controller // For live reports. Written by the main controller thread, read by HTTP threads. private final ConcurrentHashMap stagePartitionCountsForLiveReports = new ConcurrentHashMap<>(); - private WorkerSketchFetcher workerSketchFetcher; - // Time at which the query started. + // Stage number -> output channel mode. Only set for stages that have started. // For live reports. Written by the main controller thread, read by HTTP threads. + private final ConcurrentHashMap stageOutputChannelModesForLiveReports = + new ConcurrentHashMap<>(); + + private WorkerSketchFetcher workerSketchFetcher; // WorkerNumber -> WorkOrders which need to be retried and our determined by the controller. // Map is always populated in the main controller thread by addToRetryQueue, and pruned in retryFailedTasks. private final Map> workOrdersToRetry = new HashMap<>(); + + // Time at which the query started. + // For live reports. Written by the main controller thread, read by HTTP threads. private volatile DateTime queryStartTime = null; private volatile DruidNode selfDruidNode; - private volatile MSQWorkerTaskLauncher workerTaskLauncher; + private volatile WorkerManager workerManager; private volatile WorkerClient netClient; private volatile FaultsExceededChecker faultsExceededChecker = null; private Map stageToStatsMergingMode; - private WorkerMemoryParameters workerMemoryParameters; - private boolean isDurableStorageEnabled; - private final boolean isFaultToleranceEnabled; - private final boolean isFailOnEmptyInsertEnabled; private volatile SegmentLoadStatusFetcher segmentLoadWaiter; @Nullable private MSQSegmentReport segmentReport; public ControllerImpl( - final MSQControllerTask task, - final ControllerContext context + final String queryId, + final MSQSpec querySpec, + final ResultsContext resultsContext, + final ControllerContext controllerContext ) { - this.task = task; - this.context = context; - this.isDurableStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled( - task.getQuerySpec().getQuery().context() - ); - this.isFaultToleranceEnabled = MultiStageQueryContext.isFaultToleranceEnabled( - task.getQuerySpec().getQuery().context() - ); - this.isFailOnEmptyInsertEnabled = MultiStageQueryContext.isFailOnEmptyInsertEnabled( - task.getQuerySpec().getQuery().context() - ); - } - - @Override - public String id() - { - return task.getId(); + this.queryId = Preconditions.checkNotNull(queryId, "queryId"); + this.querySpec = Preconditions.checkNotNull(querySpec, "querySpec"); + this.resultsContext = Preconditions.checkNotNull(resultsContext, "resultsContext"); + this.context = Preconditions.checkNotNull(controllerContext, "controllerContext"); } @Override - public MSQControllerTask task() + public String queryId() { - return task; + return queryId; } @Override - public TaskStatus run() throws Exception + public void run(final QueryListener queryListener) throws Exception { - final Closer closer = Closer.create(); - - try { - return runTask(closer); - } - catch (Throwable e) { - try { - closer.close(); - } - catch (Throwable e2) { - 
e.addSuppressed(e2); - } - - // We really don't expect this to error out. runTask should handle everything nicely. If it doesn't, something - // strange happened, so log it. - log.warn(e, "Encountered unhandled controller exception."); - return TaskStatus.failure(id(), e.toString()); - } - finally { - closer.close(); + try (final Closer closer = Closer.create()) { + runInternal(queryListener, closer); } } @Override - public void stopGracefully() + public void stop() { final QueryDefinition queryDef = queryDefRef.get(); @@ -403,18 +354,17 @@ public void stopGracefully() } ); - if (workerTaskLauncher != null) { - workerTaskLauncher.stop(true); + if (workerManager != null) { + workerManager.stop(true); } } - public TaskStatus runTask(final Closer closer) + private void runInternal(final QueryListener queryListener, final Closer closer) { QueryDefinition queryDef = null; ControllerQueryKernel queryKernel = null; ListenableFuture workerTaskRunnerFuture = null; CounterSnapshotsTree countersSnapshot = null; - Yielder resultsYielder = null; Throwable exceptionEncountered = null; final TaskState taskStateForReport; @@ -423,17 +373,24 @@ public TaskStatus runTask(final Closer closer) try { // Planning-related: convert the native query from MSQSpec into a multi-stage QueryDefinition. this.queryStartTime = DateTimes.nowUtc(); + context.registerController(this, closer); queryDef = initializeQueryDefAndState(closer); - final InputSpecSlicerFactory inputSpecSlicerFactory = makeInputSpecSlicerFactory(makeDataSegmentTimelineView()); - // Execution-related: run the multi-stage QueryDefinition. + final InputSpecSlicerFactory inputSpecSlicerFactory = + makeInputSpecSlicerFactory(context.newTableInputSpecSlicer()); + final Pair> queryRunResult = - new RunQueryUntilDone(queryDef, inputSpecSlicerFactory, closer).run(); + new RunQueryUntilDone( + queryDef, + queryKernelConfig, + inputSpecSlicerFactory, + queryListener, + closer + ).run(); queryKernel = Preconditions.checkNotNull(queryRunResult.lhs); workerTaskRunnerFuture = Preconditions.checkNotNull(queryRunResult.rhs); - resultsYielder = getFinalResultsYielder(queryDef, queryKernel); handleQueryResults(queryDef, queryKernel); } catch (Throwable e) { @@ -458,20 +415,24 @@ public TaskStatus runTask(final Closer closer) } else { // Query failure. Generate an error report and log the error(s) we encountered. final String selfHost = MSQTasks.getHostFromSelfNode(selfDruidNode); - final MSQErrorReport controllerError = - exceptionEncountered != null - ? MSQErrorReport.fromException( - id(), - selfHost, - null, - exceptionEncountered, - task.getQuerySpec().getColumnMappings() - ) - : null; + final MSQErrorReport controllerError; + + if (exceptionEncountered != null) { + controllerError = MSQErrorReport.fromException( + queryId(), + selfHost, + null, + exceptionEncountered, + querySpec.getColumnMappings() + ); + } else { + controllerError = null; + } + MSQErrorReport workerError = workerErrorRef.get(); taskStateForReport = TaskState.FAILED; - errorForReport = MSQTasks.makeErrorReport(id(), selfHost, controllerError, workerError); + errorForReport = MSQTasks.makeErrorReport(queryId(), selfHost, controllerError, workerError); // Log the errors we encountered. 
if (controllerError != null) { @@ -482,33 +443,14 @@ public TaskStatus runTask(final Closer closer) log.warn("Worker: %s", MSQTasks.errorReportToLogMessage(workerError)); } } - MSQResultsReport resultsReport = null; if (queryKernel != null && queryKernel.isSuccess()) { // If successful, encourage the tasks to exit successfully. - // get results before posting finish to the tasks. - if (resultsYielder != null) { - resultsReport = makeResultsTaskReport( - queryDef, - resultsYielder, - task.getQuerySpec().getColumnMappings(), - task.getSqlTypeNames(), - MultiStageQueryContext.getSelectDestination(task.getQuerySpec().getQuery().context()) - ); - try { - resultsYielder.close(); - } - catch (IOException e) { - throw new RuntimeException("Unable to fetch results of various worker tasks successfully", e); - } - } else { - resultsReport = null; - } postFinishToAllTasks(); - workerTaskLauncher.stop(false); + workerManager.stop(false); } else { // If not successful, cancel running tasks. - if (workerTaskLauncher != null) { - workerTaskLauncher.stop(true); + if (workerManager != null) { + workerManager.stop(true); } } @@ -523,10 +465,13 @@ public TaskStatus runTask(final Closer closer) } } - boolean shouldWaitForSegmentLoad = MultiStageQueryContext.shouldWaitForSegmentLoad(task.getQuerySpec().getQuery().context()); + boolean shouldWaitForSegmentLoad = MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getQuery().context()); try { - releaseTaskLocks(); + if (MSQControllerTask.isIngestion(querySpec)) { + releaseTaskLocks(); + } + cleanUpDurableStorageIfNeeded(); if (queryKernel != null && queryKernel.isSuccess()) { @@ -534,7 +479,7 @@ public TaskStatus runTask(final Closer closer) // If successful, there are segments created and segment load is enabled, segmentLoadWaiter should wait // for them to become available. log.info("Controller will now wait for segments to be loaded. The query has already finished executing," - + " and results will be included once the segments are loaded, even if this query is cancelled now."); + + " and results will be included once the segments are loaded, even if this query is canceled now."); segmentLoadWaiter.waitForSegmentsToLoad(); } } @@ -544,70 +489,54 @@ public TaskStatus runTask(final Closer closer) log.warn(e, "Exception thrown during cleanup. Ignoring it and writing task report."); } - try { - // Write report even if something went wrong. - final MSQStagesReport stagesReport; - - if (queryDef != null) { - final Map stagePhaseMap; - - if (queryKernel != null) { - // Once the query finishes, cleanup would have happened for all the stages that were successful - // Therefore we mark it as done to make the reports prettier and more accurate - queryKernel.markSuccessfulTerminalStagesAsFinished(); - stagePhaseMap = queryKernel.getActiveStages() - .stream() - .collect( - Collectors.toMap(StageId::getStageNumber, queryKernel::getStagePhase) - ); - } else { - stagePhaseMap = Collections.emptyMap(); - } + // Generate report even if something went wrong. 
+ final MSQStagesReport stagesReport; - stagesReport = makeStageReport( - queryDef, - stagePhaseMap, - stageRuntimesForLiveReports, - stageWorkerCountsForLiveReports, - stagePartitionCountsForLiveReports - ); + if (queryDef != null) { + final Map stagePhaseMap; + + if (queryKernel != null) { + // Once the query finishes, cleanup would have happened for all the stages that were successful + // Therefore we mark it as done to make the reports prettier and more accurate + queryKernel.markSuccessfulTerminalStagesAsFinished(); + stagePhaseMap = queryKernel.getActiveStages() + .stream() + .collect( + Collectors.toMap(StageId::getStageNumber, queryKernel::getStagePhase) + ); } else { - stagesReport = null; - } - - final MSQTaskReportPayload taskReportPayload = new MSQTaskReportPayload( - makeStatusReport( - taskStateForReport, - errorForReport, - workerWarnings, - queryStartTime, - new Interval(queryStartTime, DateTimes.nowUtc()).toDurationMillis(), - workerTaskLauncher, - segmentLoadWaiter, - segmentReport - ), - stagesReport, - countersSnapshot, - resultsReport - ); - context.writeReports( - id(), - TaskReport.buildTaskReports( - new MSQTaskReport(id(), taskReportPayload), - new TaskContextReport(id(), task.getContext()) - ) - ); - } - catch (Throwable e) { - log.warn(e, "Error encountered while writing task report. Skipping."); - } + stagePhaseMap = Collections.emptyMap(); + } - if (taskStateForReport == TaskState.SUCCESS) { - return TaskStatus.success(id()); + stagesReport = makeStageReport( + queryDef, + stagePhaseMap, + stageRuntimesForLiveReports, + stageWorkerCountsForLiveReports, + stagePartitionCountsForLiveReports, + stageOutputChannelModesForLiveReports + ); } else { - // errorForReport is nonnull when taskStateForReport != SUCCESS. Use that message. - return TaskStatus.failure(id(), MSQFaultUtils.generateMessageWithErrorCode(errorForReport.getFault())); - } + stagesReport = null; + } + + final MSQTaskReportPayload taskReportPayload = new MSQTaskReportPayload( + makeStatusReport( + taskStateForReport, + errorForReport, + workerWarnings, + queryStartTime, + new Interval(queryStartTime, DateTimes.nowUtc()).toDurationMillis(), + workerManager, + segmentLoadWaiter, + segmentReport + ), + stagesReport, + countersSnapshot, + null + ); + + queryListener.onQueryComplete(taskReportPayload); } /** @@ -644,105 +573,59 @@ public void addToKernelManipulationQueue(Consumer kernelC private QueryDefinition initializeQueryDefAndState(final Closer closer) { - final QueryContext queryContext = task.getQuerySpec().getQuery().context(); - if (isFaultToleranceEnabled) { - if (!queryContext.containsKey(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE)) { - // if context key not set, enable durableStorage automatically. - isDurableStorageEnabled = true; - } else { - // if context key is set, and durableStorage is turned on. - if (MultiStageQueryContext.isDurableStorageEnabled(queryContext)) { - isDurableStorageEnabled = true; - } else { - throw new MSQException( - UnknownFault.forMessage( - StringUtils.format( - "Context param[%s] cannot be explicitly set to false when context param[%s] is" - + " set to true. 
Either remove the context param[%s] or explicitly set it to true.", - MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, - MultiStageQueryContext.CTX_FAULT_TOLERANCE, - MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE - ))); - } - } - } else { - isDurableStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(queryContext); - } - - log.debug("Task [%s] durable storage mode is set to %s.", task.getId(), isDurableStorageEnabled); - log.debug("Task [%s] fault tolerance mode is set to %s.", task.getId(), isFaultToleranceEnabled); - this.selfDruidNode = context.selfNode(); - context.registerController(this, closer); - - this.netClient = new ExceptionWrappingWorkerClient(context.taskClientFor(this)); - closer.register(netClient::close); + this.netClient = new ExceptionWrappingWorkerClient(context.newWorkerClient()); + closer.register(netClient); final QueryDefinition queryDef = makeQueryDefinition( - id(), + queryId(), makeQueryControllerToolKit(), - task.getQuerySpec(), + querySpec, context.jsonMapper() ); - QueryValidator.validateQueryDef(queryDef); - queryDefRef.set(queryDef); - - final long maxParseExceptions = task.getQuerySpec().getQuery().context().getLong( - MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, - MSQWarnings.DEFAULT_MAX_PARSE_EXCEPTIONS_ALLOWED - ); - - ImmutableMap.Builder taskContextOverridesBuilder = ImmutableMap.builder(); - taskContextOverridesBuilder - .put(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, isDurableStorageEnabled) - .put(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, maxParseExceptions); - - if (!MSQControllerTask.isIngestion(task.getQuerySpec())) { - if (MSQControllerTask.writeResultsToDurableStorage(task.getQuerySpec())) { - taskContextOverridesBuilder.put( - MultiStageQueryContext.CTX_SELECT_DESTINATION, - MSQSelectDestination.DURABLESTORAGE.getName() - ); - } else { - // we need not pass the value 'TaskReport' to the worker since the worker impl does not do anything in such a case. 
- // but we are passing it anyway for completeness - taskContextOverridesBuilder.put( - MultiStageQueryContext.CTX_SELECT_DESTINATION, - MSQSelectDestination.TASKREPORT.getName() + if (log.isDebugEnabled()) { + try { + log.debug( + "Query[%s] definition: %s", + queryDef.getQueryId(), + context.jsonMapper().writerWithDefaultPrettyPrinter().writeValueAsString(queryDef) ); } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } } - taskContextOverridesBuilder.put( - MultiStageQueryContext.CTX_IS_REINDEX, - MSQControllerTask.isReplaceInputDataSourceTask(task) - ); - - // propagate the controller's tags to the worker task for enhanced metrics reporting - Map tags = task.getContextValue(DruidMetrics.TAGS); - if (tags != null) { - taskContextOverridesBuilder.put(DruidMetrics.TAGS, tags); - } + QueryValidator.validateQueryDef(queryDef); + queryDefRef.set(queryDef); - this.workerTaskLauncher = new MSQWorkerTaskLauncher( - id(), - task.getDataSource(), - context, + queryKernelConfig = context.queryKernelConfig(querySpec, queryDef); + workerManager = context.newWorkerManager( + queryId, + querySpec, + queryKernelConfig, (failedTask, fault) -> { - if (isFaultToleranceEnabled && ControllerQueryKernel.isRetriableFault(fault)) { - addToKernelManipulationQueue((kernel) -> { + if (queryKernelConfig.isFaultTolerant() && ControllerQueryKernel.isRetriableFault(fault)) { + addToKernelManipulationQueue(kernel -> { addToRetryQueue(kernel, failedTask.getWorkerNumber(), fault); }); } else { throw new MSQException(fault); } - }, - taskContextOverridesBuilder.build(), - // 10 minutes +- 2 minutes jitter - TimeUnit.SECONDS.toMillis(600 + ThreadLocalRandom.current().nextInt(-4, 5) * 30L) + } ); + if (queryKernelConfig.isFaultTolerant() && !(workerManager instanceof RetryCapableWorkerManager)) { + // Not expected to happen, since all WorkerManager impls are currently retry-capable. Defensive check + // for future-proofing. + throw DruidException.defensive( + "Cannot run with fault tolerance since workerManager class[%s] does not support retrying", + workerManager.getClass().getName() + ); + } + + final long maxParseExceptions = MultiStageQueryContext.getMaxParseExceptions(querySpec.getQuery().context()); this.faultsExceededChecker = new FaultsExceededChecker( ImmutableMap.of(CannotParseExternalDataFault.CODE, maxParseExceptions) ); @@ -754,15 +637,14 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer) stageDefinition.getId().getStageNumber(), finalizeClusterStatisticsMergeMode( stageDefinition, - MultiStageQueryContext.getClusterStatisticsMergeMode(queryContext) + MultiStageQueryContext.getClusterStatisticsMergeMode(querySpec.getQuery().context()) ) ) ); - this.workerMemoryParameters = WorkerMemoryParameters.createProductionInstanceForController(context.injector()); this.workerSketchFetcher = new WorkerSketchFetcher( netClient, - workerTaskLauncher, - isFaultToleranceEnabled + workerManager, + queryKernelConfig.isFaultTolerant() ); closer.register(workerSketchFetcher::close); @@ -777,10 +659,14 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer) */ private void addToRetryQueue(ControllerQueryKernel kernel, int worker, MSQFault fault) { + // Blind cast to RetryCapableWorkerManager is safe, since we verified that workerManager is retry-capable + // when initially creating it. 
+ final RetryCapableWorkerManager retryCapableWorkerManager = (RetryCapableWorkerManager) workerManager; + List retriableWorkOrders = kernel.getWorkInCaseWorkerEligibleForRetryElseThrow(worker, fault); - if (retriableWorkOrders.size() != 0) { + if (!retriableWorkOrders.isEmpty()) { log.info("Submitting worker[%s] for relaunch because of fault[%s]", worker, fault); - workerTaskLauncher.submitForRelaunch(worker); + retryCapableWorkerManager.submitForRelaunch(worker); workOrdersToRetry.compute(worker, (workerNumber, workOrders) -> { if (workOrders == null) { return new HashSet<>(retriableWorkOrders); @@ -790,11 +676,11 @@ private void addToRetryQueue(ControllerQueryKernel kernel, int worker, MSQFault } }); } else { - log.info( + log.debug( "Worker[%d] has no active workOrders that need relaunch therefore not relaunching", worker ); - workerTaskLauncher.reportFailedInactiveWorker(worker); + retryCapableWorkerManager.reportFailedInactiveWorker(worker); } } @@ -813,6 +699,11 @@ public void updatePartialKeyStatisticsInformation( addToKernelManipulationQueue( queryKernel -> { final StageId stageId = queryKernel.getStageId(stageNumber); + + if (queryKernel.isStageFinished(stageId)) { + return; + } + final PartialKeyStatisticsInformation partialKeyStatisticsInformation; try { @@ -835,19 +726,41 @@ public void updatePartialKeyStatisticsInformation( ); } + @Override + public void doneReadingInput(int stageNumber, int workerNumber) + { + addToKernelManipulationQueue( + queryKernel -> { + final StageId stageId = queryKernel.getStageId(stageNumber); + + if (queryKernel.isStageFinished(stageId)) { + return; + } + + queryKernel.setDoneReadingInputForStageAndWorker(stageId, workerNumber); + } + ); + } @Override public void workerError(MSQErrorReport errorReport) { - if (workerTaskLauncher.isTaskCanceledByController(errorReport.getTaskId()) || - !workerTaskLauncher.isTaskLatest(errorReport.getTaskId())) { - log.info("Ignoring task %s", errorReport.getTaskId()); - } else { - workerErrorRef.compareAndSet( - null, - mapQueryColumnNameToOutputColumnName(errorReport) - ); + if (queryKernelConfig.isFaultTolerant()) { + // Blind cast to RetryCapableWorkerManager in fault-tolerant mode is safe, since when fault-tolerance is + // enabled, we verify that workerManager is retry-capable when initially creating it. 
+ final RetryCapableWorkerManager retryCapableWorkerManager = (RetryCapableWorkerManager) workerManager; + + if (retryCapableWorkerManager.isTaskCanceledByController(errorReport.getTaskId()) || + !retryCapableWorkerManager.isWorkerActive(errorReport.getTaskId())) { + log.debug( + "Ignoring error report for worker[%s] because it was intentionally shut down.", + errorReport.getTaskId() + ); + return; + } } + + workerErrorRef.compareAndSet(null, mapQueryColumnNameToOutputColumnName(errorReport)); } /** @@ -920,6 +833,11 @@ public void resultsComplete( addToKernelManipulationQueue( queryKernel -> { final StageId stageId = new StageId(queryId, stageNumber); + + if (queryKernel.isStageFinished(stageId)) { + return; + } + final Object convertedResultObject; try { convertedResultObject = context.jsonMapper().convertValue( @@ -936,7 +854,6 @@ public void resultsComplete( ); } - queryKernel.setResultsCompleteForStageAndWorker(stageId, workerNumber, convertedResultObject); } ); @@ -954,7 +871,7 @@ public TaskReport.ReportMap liveReports() return TaskReport.buildTaskReports( new MSQTaskReport( - id(), + queryId(), new MSQTaskReportPayload( makeStatusReport( TaskState.RUNNING, @@ -962,7 +879,7 @@ public TaskReport.ReportMap liveReports() workerWarnings, queryStartTime, queryStartTime == null ? -1L : new Interval(queryStartTime, DateTimes.nowUtc()).toDurationMillis(), - workerTaskLauncher, + workerManager, segmentLoadWaiter, segmentReport ), @@ -971,7 +888,8 @@ public TaskReport.ReportMap liveReports() stagePhasesForLiveReports, stageRuntimesForLiveReports, stageWorkerCountsForLiveReports, - stagePartitionCountsForLiveReports + stagePartitionCountsForLiveReports, + stageOutputChannelModesForLiveReports ), makeCountersSnapshotForLiveReports(), null @@ -982,9 +900,9 @@ public TaskReport.ReportMap liveReports() /** * @param isStageOutputEmpty {@code true} if the stage output is empty, {@code false} if the stage output is non-empty, - * {@code null} for stages where cluster key statistics are not gathered or is incomplete. + * {@code null} for stages where cluster key statistics are not gathered or is incomplete. * - * @return the segments that will be generated by this job. Delegates to + * @return the segments that will be generated by this job. Delegates to * {@link #generateSegmentIdsWithShardSpecsForAppend} or {@link #generateSegmentIdsWithShardSpecsForReplace} as * appropriate. This is a potentially expensive call, since it requires calling Overlord APIs. * @@ -1014,7 +932,7 @@ private List generateSegmentIdsWithShardSpecs( destination, partitionBoundaries, keyReader, - MultiStageQueryContext.validateAndGetTaskLockType(QueryContext.of(task.getQuerySpec().getQuery().getContext()), false), + MultiStageQueryContext.validateAndGetTaskLockType(QueryContext.of(querySpec.getQuery().getContext()), false), isStageOutputEmpty ); } @@ -1024,7 +942,7 @@ private List generateSegmentIdsWithShardSpecs( * Used by {@link #generateSegmentIdsWithShardSpecs}. * * @param isStageOutputEmpty {@code true} if the stage output is empty, {@code false} if the stage output is non-empty, - * {@code null} for stages where cluster key statistics are not gathered or is incomplete. + * {@code null} for stages where cluster key statistics are not gathered or is incomplete. 
*/ private List generateSegmentIdsWithShardSpecsForAppend( final DataSourceMSQDestination destination, @@ -1055,13 +973,13 @@ private List generateSegmentIdsWithShardSpecsForAppend( try { allocation = context.taskActionClient().submit( new SegmentAllocateAction( - task.getDataSource(), + destination.getDataSource(), timestamp, // Same granularity for queryGranularity, segmentGranularity because we don't have insight here // into what queryGranularity "actually" is. (It depends on what time floor function was used.) segmentGranularity, segmentGranularity, - id(), + queryId(), previousSegmentId, false, NumberedPartialShardSpec.instance(), @@ -1081,7 +999,7 @@ private List generateSegmentIdsWithShardSpecsForAppend( if (allocation == null) { throw new MSQException( new InsertCannotAllocateSegmentFault( - task.getDataSource(), + destination.getDataSource(), segmentGranularity.bucket(timestamp), null ) @@ -1095,7 +1013,7 @@ private List generateSegmentIdsWithShardSpecsForAppend( if (!IntervalUtils.isAligned(allocation.getInterval(), segmentGranularity)) { throw new MSQException( new InsertCannotAllocateSegmentFault( - task.getDataSource(), + destination.getDataSource(), segmentGranularity.bucket(timestamp), allocation.getInterval() ) @@ -1113,8 +1031,7 @@ private List generateSegmentIdsWithShardSpecsForAppend( * Used by {@link #generateSegmentIdsWithShardSpecs}. * * @param isStageOutputEmpty {@code true} if the stage output is empty, {@code false} if the stage output is non-empty, - * {@code null} for stages where cluster key statistics are not gathered or is incomplete. - * + * {@code null} for stages where cluster key statistics are not gathered or is incomplete. */ private List generateSegmentIdsWithShardSpecsForReplace( final DataSourceMSQDestination destination, @@ -1135,10 +1052,17 @@ private List generateSegmentIdsWithShardSpecsForReplace( final List shardColumns; final Pair, String> shardReasonPair; - shardReasonPair = computeShardColumns(signature, clusterBy, task.getQuerySpec().getColumnMappings(), mayHaveMultiValuedClusterByFields); + shardReasonPair = computeShardColumns( + signature, + clusterBy, + querySpec.getColumnMappings(), + mayHaveMultiValuedClusterByFields + ); + shardColumns = shardReasonPair.lhs; String reason = shardReasonPair.rhs; - log.info(StringUtils.format("ShardSpec chosen: %s", reason)); + log.info("ShardSpec chosen: %s", reason); + if (shardColumns.isEmpty()) { segmentReport = new MSQSegmentReport(NumberedShardSpec.class.getSimpleName(), reason); } else { @@ -1194,26 +1118,21 @@ private List generateSegmentIdsWithShardSpecsForReplace( shardSpec = new DimensionRangeShardSpec(shardColumns, start, end, segmentNumber, ranges.size()); } - retVal[partitionNumber] = new SegmentIdWithShardSpec(task.getDataSource(), interval, version, shardSpec); + retVal[partitionNumber] = new SegmentIdWithShardSpec(destination.getDataSource(), interval, version, shardSpec); } } return Arrays.asList(retVal); } - /** - * Returns a complete list of task ids, ordered by worker number. The Nth task has worker number N. - *
- * If the currently-running set of tasks is incomplete, returns an absent Optional. - */ @Override public List getTaskIds() { - if (workerTaskLauncher == null) { + if (workerManager == null) { return Collections.emptyList(); } - return workerTaskLauncher.getActiveTasks(); + return workerManager.getWorkerIds(); } @SuppressWarnings({"unchecked", "rawtypes"}) @@ -1225,7 +1144,7 @@ private Int2ObjectMap makeWorkerFactoryInfosForStage( @Nullable final List segmentsToGenerate ) { - if (MSQControllerTask.isIngestion(task.getQuerySpec()) && + if (MSQControllerTask.isIngestion(querySpec) && stageNumber == queryDef.getFinalStageDefinition().getStageNumber()) { // noinspection unchecked,rawtypes return (Int2ObjectMap) makeSegmentGeneratorWorkerFactoryInfos(workerInputs, segmentsToGenerate); @@ -1247,94 +1166,6 @@ private QueryKit makeQueryControllerToolKit() return new MultiQueryKit(kitMap); } - private DataSegmentTimelineView makeDataSegmentTimelineView() - { - final SegmentSource includeSegmentSource = MultiStageQueryContext.getSegmentSources( - task.getQuerySpec() - .getQuery() - .context() - ); - - final boolean includeRealtime = SegmentSource.shouldQueryRealtimeServers(includeSegmentSource); - - return (dataSource, intervals) -> { - final Iterable realtimeAndHistoricalSegments; - - // Fetch the realtime segments and segments loaded on the historical. Do this first so that we don't miss any - // segment if they get handed off between the two calls. Segments loaded on historicals are deduplicated below, - // since we are only interested in realtime segments for now. - if (includeRealtime) { - realtimeAndHistoricalSegments = context.coordinatorClient().fetchServerViewSegments(dataSource, intervals); - } else { - realtimeAndHistoricalSegments = ImmutableList.of(); - } - - // Fetch all published, used segments (all non-realtime segments) from the metadata store. - // If the task is operating with a REPLACE lock, - // any segment created after the lock was acquired for its interval will not be considered. - final Collection publishedUsedSegments; - try { - // Additional check as the task action does not accept empty intervals - if (intervals.isEmpty()) { - publishedUsedSegments = Collections.emptySet(); - } else { - publishedUsedSegments = context.taskActionClient().submit(new RetrieveUsedSegmentsAction( - dataSource, - intervals - )); - } - } - catch (IOException e) { - throw new MSQException(e, UnknownFault.forException(e)); - } - - int realtimeCount = 0; - - // Deduplicate segments, giving preference to published used segments. - // We do this so that if any segments have been handed off in between the two metadata calls above, - // we directly fetch it from deep storage. - Set unifiedSegmentView = new HashSet<>(publishedUsedSegments); - - // Iterate over the realtime segments and segments loaded on the historical - for (ImmutableSegmentLoadInfo segmentLoadInfo : realtimeAndHistoricalSegments) { - ImmutableSet servers = segmentLoadInfo.getServers(); - // Filter out only realtime servers. We don't want to query historicals for now, but we can in the future. - // This check can be modified then. 
- Set realtimeServerMetadata - = servers.stream() - .filter(druidServerMetadata -> includeSegmentSource.getUsedServerTypes() - .contains(druidServerMetadata.getType()) - ) - .collect(Collectors.toSet()); - if (!realtimeServerMetadata.isEmpty()) { - realtimeCount += 1; - DataSegmentWithLocation dataSegmentWithLocation = new DataSegmentWithLocation( - segmentLoadInfo.getSegment(), - realtimeServerMetadata - ); - unifiedSegmentView.add(dataSegmentWithLocation); - } else { - // We don't have any segments of the required segment source, ignore the segment - } - } - - if (includeRealtime) { - log.info( - "Fetched total [%d] segments from coordinator: [%d] from metadata stoure, [%d] from server view", - unifiedSegmentView.size(), - publishedUsedSegments.size(), - realtimeCount - ); - } - - if (unifiedSegmentView.isEmpty()) { - return Optional.empty(); - } else { - return Optional.of(SegmentTimeline.forSegments(unifiedSegmentView)); - } - }; - } - private Int2ObjectMap> makeSegmentGeneratorWorkerFactoryInfos( final WorkerInputs workerInputs, final List segmentsToGenerate @@ -1369,75 +1200,59 @@ private Int2ObjectMap> makeSegmentGeneratorWorkerFa * * @param queryKernel * @param contactFn - * @param workers set of workers to contact - * @param successCallBack After contacting all the tasks, a custom callback is invoked in the main thread for each successfully contacted task. - * @param retryOnFailure If true, after contacting all the tasks, adds this worker to retry queue in the main thread. - * If false, cancel all the futures and propagate the exception to the caller. + * @param workers set of workers to contact + * @param successFn After contacting all the tasks, a custom callback is invoked in the main thread for each successfully contacted task. + * @param retryOnFailure If true, after contacting all the tasks, adds this worker to retry queue in the main thread. + * If false, cancel all the futures and propagate the exception to the caller. */ private void contactWorkersForStage( final ControllerQueryKernel queryKernel, - final TaskContactFn contactFn, final IntSet workers, - final TaskContactSuccess successCallBack, + final TaskContactFn contactFn, + final TaskContactSuccess successFn, final boolean retryOnFailure ) { - final List taskIds = getTaskIds(); - final List> taskFutures = new ArrayList<>(workers.size()); + // Sorted copy of target worker numbers to ensure consistent iteration order. + final List workersCopy = Ordering.natural().sortedCopy(workers); + final List workerIds = getTaskIds(); + final List> workerFutures = new ArrayList<>(workersCopy.size()); try { - workerTaskLauncher.waitUntilWorkersReady(workers); + workerManager.waitForWorkers(workers); } catch (InterruptedException e) { + Thread.currentThread().interrupt(); throw new RuntimeException(e); } - Set failedCalls = ConcurrentHashMap.newKeySet(); - Set successfulCalls = ConcurrentHashMap.newKeySet(); - - for (int workerNumber : workers) { - final String taskId = taskIds.get(workerNumber); - SettableFuture settableFuture = SettableFuture.create(); - ListenableFuture apiFuture = contactFn.contactTask(netClient, taskId, workerNumber); - Futures.addCallback(apiFuture, new FutureCallback() - { - @Override - public void onSuccess(@Nullable Void result) - { - successfulCalls.add(taskId); - settableFuture.set(true); - } - - @Override - public void onFailure(Throwable t) - { - if (retryOnFailure) { - log.info( - t, - "Detected failure while contacting task[%s]. 
Initiating relaunch of worker[%d] if applicable", - taskId, - MSQTasks.workerFromTaskId(taskId) - ); - failedCalls.add(taskId); - settableFuture.set(false); - } else { - settableFuture.setException(t); - } - } - }, MoreExecutors.directExecutor()); - - taskFutures.add(settableFuture); + for (final int workerNumber : workersCopy) { + workerFutures.add(contactFn.contactTask(netClient, workerIds.get(workerNumber), workerNumber)); } - FutureUtils.getUnchecked(MSQFutureUtils.allAsList(taskFutures, true), true); + final List> workerResults = + FutureUtils.getUnchecked(FutureUtils.coalesce(workerFutures), true); - for (String taskId : successfulCalls) { - successCallBack.onSuccess(taskId); - } + for (int i = 0; i < workerResults.size(); i++) { + final int workerNumber = workersCopy.get(i); + final String workerId = workerIds.get(workerNumber); + final Either workerResult = workerResults.get(i); + + if (workerResult.isValue()) { + successFn.onSuccess(workerId, workerNumber); + } else if (retryOnFailure) { + // Possibly retryable failure. + log.info( + workerResult.error(), + "Detected failure while contacting task[%s]. Initiating relaunch of worker[%d] if applicable", + workerId, + workerNumber + ); - if (retryOnFailure) { - for (String taskId : failedCalls) { - addToRetryQueue(queryKernel, MSQTasks.workerFromTaskId(taskId), new WorkerRpcFailedFault(taskId)); + addToRetryQueue(queryKernel, workerNumber, new WorkerRpcFailedFault(workerId)); + } else { + // Nonretryable failure. + throw new RuntimeException(workerResult.error()); } } } @@ -1462,10 +1277,12 @@ private void startWorkForStage( queryKernel.startStage(stageId); contactWorkersForStage( queryKernel, + workOrders.keySet(), (netClient, taskId, workerNumber) -> ( - netClient.postWorkOrder(taskId, workOrders.get(workerNumber))), workOrders.keySet(), - (taskId) -> queryKernel.workOrdersSentForWorker(stageId, MSQTasks.workerFromTaskId(taskId)), - isFaultToleranceEnabled + netClient.postWorkOrder(taskId, workOrders.get(workerNumber))), + (workerId, workerNumber) -> + queryKernel.workOrdersSentForWorker(stageId, workerNumber), + queryKernelConfig.isFaultTolerant() ); } @@ -1481,14 +1298,12 @@ private void postResultPartitionBoundariesForStage( contactWorkersForStage( queryKernel, - (netClient, taskId, workerNumber) -> netClient.postResultPartitionBoundaries( - taskId, - stageId, - resultPartitionBoundaries - ), workers, - (taskId) -> queryKernel.partitionBoundariesSentForWorker(stageId, MSQTasks.workerFromTaskId(taskId)), - isFaultToleranceEnabled + (netClient, workerId, workerNumber) -> + netClient.postResultPartitionBoundaries(workerId, stageId, resultPartitionBoundaries), + (workerId, workerNumber) -> + queryKernel.partitionBoundariesSentForWorker(stageId, workerNumber), + queryKernelConfig.isFaultTolerant() ); } @@ -1499,11 +1314,11 @@ private void postResultPartitionBoundariesForStage( private void publishAllSegments(final Set segments) throws IOException { final DataSourceMSQDestination destination = - (DataSourceMSQDestination) task.getQuerySpec().getDestination(); + (DataSourceMSQDestination) querySpec.getDestination(); final Set segmentsWithTombstones = new HashSet<>(segments); int numTombstones = 0; final TaskLockType taskLockType = MultiStageQueryContext.validateAndGetTaskLockType( - QueryContext.of(task.getQuerySpec().getQuery().getContext()), + QueryContext.of(querySpec.getQuery().getContext()), destination.isReplaceTimeChunks() ); @@ -1516,7 +1331,7 @@ private void publishAllSegments(final Set segments) throws IOExcept Set tombstones 
= tombstoneHelper.computeTombstoneSegmentsForReplace( intervalsToDrop, destination.getReplaceTimeChunks(), - task.getDataSource(), + destination.getDataSource(), destination.getSegmentGranularity(), Limits.MAX_PARTITION_BUCKETS ); @@ -1537,15 +1352,15 @@ private void publishAllSegments(final Set segments) throws IOExcept // This should not need a segment load wait as segments are marked as unused immediately. for (final Interval interval : intervalsToDrop) { context.taskActionClient() - .submit(new MarkSegmentsAsUnusedAction(task.getDataSource(), interval)); + .submit(new MarkSegmentsAsUnusedAction(destination.getDataSource(), interval)); } } else { - if (MultiStageQueryContext.shouldWaitForSegmentLoad(task.getQuerySpec().getQuery().context())) { + if (MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getQuery().context())) { segmentLoadWaiter = new SegmentLoadStatusFetcher( context.injector().getInstance(BrokerClient.class), context.jsonMapper(), - task.getId(), - task.getDataSource(), + queryId, + destination.getDataSource(), segmentsWithTombstones, true ); @@ -1556,12 +1371,12 @@ private void publishAllSegments(final Set segments) throws IOExcept ); } } else if (!segments.isEmpty()) { - if (MultiStageQueryContext.shouldWaitForSegmentLoad(task.getQuerySpec().getQuery().context())) { + if (MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getQuery().context())) { segmentLoadWaiter = new SegmentLoadStatusFetcher( context.injector().getInstance(BrokerClient.class), context.jsonMapper(), - task.getId(), - task.getDataSource(), + queryId, + destination.getDataSource(), segments, true ); @@ -1573,9 +1388,9 @@ private void publishAllSegments(final Set segments) throws IOExcept ); } - task.emitMetric(context.emitter(), "ingest/tombstones/count", numTombstones); + context.emitMetric("ingest/tombstones/count", numTombstones); // Include tombstones in the reported segments count - task.emitMetric(context.emitter(), "ingest/segments/count", segmentsWithTombstones.size()); + context.emitMetric("ingest/segments/count", segmentsWithTombstones.size()); } private static TaskAction createAppendAction( @@ -1614,7 +1429,7 @@ private List findIntervalsToDrop(final Set publishedSegme { // Safe to cast because publishAllSegments is only called for dataSource destinations. 
final DataSourceMSQDestination destination = - (DataSourceMSQDestination) task.getQuerySpec().getDestination(); + (DataSourceMSQDestination) querySpec.getDestination(); final List replaceIntervals = new ArrayList<>(JodaUtils.condenseIntervals(destination.getReplaceTimeChunks())); final List publishIntervals = @@ -1671,80 +1486,6 @@ private CounterSnapshotsTree getFinalCountersSnapshot(@Nullable final Controller } } - @Nullable - private Yielder getFinalResultsYielder( - final QueryDefinition queryDef, - final ControllerQueryKernel queryKernel - ) - { - if (queryKernel.isSuccess() && isInlineResults(task.getQuerySpec())) { - final StageId finalStageId = queryKernel.getStageId(queryDef.getFinalStageDefinition().getStageNumber()); - final List taskIds = getTaskIds(); - final Closer closer = Closer.create(); - - final ListeningExecutorService resultReaderExec = - MoreExecutors.listeningDecorator(Execs.singleThreaded("result-reader-%d")); - closer.register(resultReaderExec::shutdownNow); - - final InputChannelFactory inputChannelFactory; - - if (isDurableStorageEnabled || MSQControllerTask.writeResultsToDurableStorage(task.getQuerySpec())) { - inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation( - id(), - MSQTasks.makeStorageConnector( - context.injector()), - closer, - MSQControllerTask.writeResultsToDurableStorage(task.getQuerySpec()) - ); - } else { - inputChannelFactory = new WorkerInputChannelFactory(netClient, () -> taskIds); - } - - final InputChannels inputChannels = new InputChannelsImpl( - queryDef, - queryKernel.getResultPartitionsForStage(finalStageId), - inputChannelFactory, - () -> ArenaMemoryAllocator.createOnHeap(5_000_000), - new FrameProcessorExecutor(resultReaderExec), - null - ); - - return Yielders.each( - Sequences.concat( - StreamSupport.stream(queryKernel.getResultPartitionsForStage(finalStageId).spliterator(), false) - .map( - readablePartition -> { - try { - return new FrameChannelSequence( - inputChannels.openChannel( - new StagePartition( - queryKernel.getStageDefinition(finalStageId).getId(), - readablePartition.getPartitionNumber() - ) - ) - ); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - ).collect(Collectors.toList()) - ).flatMap( - frame -> - SqlStatementResourceHelper.getResultSequence( - task, - queryDef.getFinalStageDefinition(), - frame, - context.jsonMapper() - ) - ) - .withBaggage(resultReaderExec::shutdownNow) - ); - } else { - return null; - } - } - private void handleQueryResults( final QueryDefinition queryDef, final ControllerQueryKernel queryKernel @@ -1753,22 +1494,21 @@ private void handleQueryResults( if (!queryKernel.isSuccess()) { return; } - if (MSQControllerTask.isIngestion(task.getQuerySpec())) { + if (MSQControllerTask.isIngestion(querySpec)) { // Publish segments if needed. 
final StageId finalStageId = queryKernel.getStageId(queryDef.getFinalStageDefinition().getStageNumber()); - //noinspection unchecked @SuppressWarnings("unchecked") Set segments = (Set) queryKernel.getResultObjectForStage(finalStageId); - boolean storeCompactionState = QueryContext.of(task.getQuerySpec().getQuery().getContext()) + boolean storeCompactionState = QueryContext.of(querySpec.getQuery().getContext()) .getBoolean( Tasks.STORE_COMPACTION_STATE_KEY, Tasks.DEFAULT_STORE_COMPACTION_STATE ); if (!segments.isEmpty() && storeCompactionState) { - DataSourceMSQDestination destination = (DataSourceMSQDestination) task.getQuerySpec().getDestination(); + DataSourceMSQDestination destination = (DataSourceMSQDestination) querySpec.getDestination(); if (!destination.isReplaceTimeChunks()) { // Store compaction state only for replace queries. log.warn( @@ -1782,7 +1522,7 @@ private void handleQueryResults( ShardSpec shardSpec = segments.stream().findFirst().get().getShardSpec(); Function, Set> compactionStateAnnotateFunction = addCompactionStateToSegments( - task(), + querySpec, context.jsonMapper(), dataSchema, shardSpec, @@ -1793,9 +1533,28 @@ private void handleQueryResults( } log.info("Query [%s] publishing %d segments.", queryDef.getQueryId(), segments.size()); publishAllSegments(segments); - } else if (MSQControllerTask.isExport(task.getQuerySpec())) { + } else if (MSQControllerTask.isExport(querySpec)) { + // Write manifest file. + ExportMSQDestination destination = (ExportMSQDestination) querySpec.getDestination(); + ExportMetadataManager exportMetadataManager = new ExportMetadataManager(destination.getExportStorageProvider()); + + final StageId finalStageId = queryKernel.getStageId(queryDef.getFinalStageDefinition().getStageNumber()); + Object resultObjectForStage = queryKernel.getResultObjectForStage(finalStageId); + if (!(resultObjectForStage instanceof List)) { + // This might occur if all workers are running on an older version. We are not able to write a manifest file in this case. + log.warn("Unable to write export manifest file: expected a list of exported files for the final stage, but got[%s] instead. Workers may be running an older version.", resultObjectForStage); + return; + } + @SuppressWarnings("unchecked") + List exportedFiles = (List) queryKernel.getResultObjectForStage(finalStageId); + log.info("Query [%s] exported %d files.", queryDef.getQueryId(), exportedFiles.size()); + exportMetadataManager.writeMetadata(exportedFiles);
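[Editor's note] The substitutions that run through these hunks (`task.getQuerySpec()` becomes `querySpec`, `task.getId()` becomes `queryId`, `task()` becomes `querySpec`) detach `ControllerImpl` from the `MSQControllerTask` wrapper, so a controller can be hosted by anything that supplies a spec and a listener. A minimal sketch of the resulting call pattern; the executor name and error handling are illustrative assumptions, while `run(QueryListener)` and `queryId()` are the patch's API, per the `QueryListener` javadoc later in this diff:

// Illustrative sketch, not part of the patch: run the controller on a dedicated
// thread and let the QueryListener receive rows as they stream in.
final ExecutorService exec = Execs.singleThreaded("controller-runner-%d"); // name is an assumption
exec.submit(() -> {
  try {
    controller.run(queryListener); // blocks until the query finishes or is stopped
  }
  catch (Exception e) {
    log.warn(e, "Controller for query[%s] failed.", controller.queryId());
  }
});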
@@ -1816,14 +1575,14 @@ private void handleQueryResults( } private static Function, Set> addCompactionStateToSegments( - MSQControllerTask task, + MSQSpec querySpec, ObjectMapper jsonMapper, DataSchema dataSchema, ShardSpec shardSpec, String queryId ) { - final MSQTuningConfig tuningConfig = task.getQuerySpec().getTuningConfig(); + final MSQTuningConfig tuningConfig = querySpec.getTuningConfig(); PartitionsSpec partitionSpec; if (Objects.equals(shardSpec.getType(), ShardSpec.Type.RANGE)) { @@ -1848,7 +1607,7 @@ private static Function, Set> addCompactionStateTo ))); } - Granularity segmentGranularity = ((DataSourceMSQDestination) task.getQuerySpec().getDestination()) + Granularity segmentGranularity = ((DataSourceMSQDestination) querySpec.getDestination()) .getSegmentGranularity(); GranularitySpec granularitySpec = new UniformGranularitySpec( @@ -1895,15 +1654,15 @@ private static Function, Set> addCompactionStateTo */ private void cleanUpDurableStorageIfNeeded() { - if (isDurableStorageEnabled) { - final String controllerDirName = DurableStorageUtils.getControllerDirectory(task.getId()); + if (queryKernelConfig != null && queryKernelConfig.isDurableStorage()) { + final String controllerDirName = DurableStorageUtils.getControllerDirectory(queryId()); try { // Delete all temporary files as a failsafe MSQTasks.makeStorageConnector(context.injector()).deleteRecursively(controllerDirName); } catch (Exception e) { // If an error is thrown while cleaning up a file, log it and try to continue with the cleanup - log.warn(e, "Error while cleaning up temporary files at path %s", controllerDirName); + log.warn(e, "Error while cleaning up temporary files at path[%s]. Skipping.", controllerDirName); } } } @@ -1945,10 +1704,9 @@ private static QueryDefinition makeQueryDefinition( queryToPlan = querySpec.getQuery(); } } else { - shuffleSpecFactory = querySpec.getDestination() - .getShuffleSpecFactory( - MultiStageQueryContext.getRowsPerPage(querySpec.getQuery().context()) - ); + shuffleSpecFactory = + querySpec.getDestination() + .getShuffleSpecFactory(MultiStageQueryContext.getRowsPerPage(querySpec.getQuery().context())); queryToPlan = querySpec.getQuery(); } @@ -1992,7 +1750,7 @@ private static QueryDefinition makeQueryDefinition( // Add all query stages. // Set shuffleCheckHasMultipleValues on the stage that serves as input to the final segment-generation stage. - final QueryDefinitionBuilder builder = QueryDefinition.builder(); + final QueryDefinitionBuilder builder = QueryDefinition.builder(queryId); for (final StageDefinition stageDef : queryDef.getStageDefinitions()) { if (stageDef.equals(finalShuffleStageDef)) { @@ -2004,7 +1762,7 @@ private static QueryDefinition makeQueryDefinition( // Then, add a segment-generation stage.
final DataSchema dataSchema = - generateDataSchema(querySpec, querySignature, queryClusterBy, columnMappings, jsonMapper); + makeDataSchemaForIngestion(querySpec, querySignature, queryClusterBy, columnMappings, jsonMapper); builder.add( StageDefinition.builder(queryDef.getNextStageNumber()) @@ -2027,7 +1785,7 @@ private static QueryDefinition makeQueryDefinition( // attaching new query results stage if the final stage does sort during shuffle so that results are ordered. StageDefinition finalShuffleStageDef = queryDef.getFinalStageDefinition(); if (finalShuffleStageDef.doesSortDuringShuffle()) { - final QueryDefinitionBuilder builder = QueryDefinition.builder(); + final QueryDefinitionBuilder builder = QueryDefinition.builder(queryId); builder.addAll(queryDef); builder.add(StageDefinition.builder(queryDef.getNextStageNumber()) .inputs(new StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber())) @@ -2063,9 +1821,8 @@ private static QueryDefinition makeQueryDefinition( .build(e, "Exception occurred while connecting to export destination."); } - final ResultFormat resultFormat = exportMSQDestination.getResultFormat(); - final QueryDefinitionBuilder builder = QueryDefinition.builder(); + final QueryDefinitionBuilder builder = QueryDefinition.builder(queryId); builder.addAll(queryDef); builder.add(StageDefinition.builder(queryDef.getNextStageNumber()) .inputs(new StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber())) @@ -2085,7 +1842,12 @@ private static QueryDefinition makeQueryDefinition( } } - private static DataSchema generateDataSchema( + private static String getDataSourceForIngestion(final MSQSpec querySpec) + { + return ((DataSourceMSQDestination) querySpec.getDestination()).getDataSource(); + } + + private static DataSchema makeDataSchemaForIngestion( MSQSpec querySpec, RowSignature querySignature, ClusterBy queryClusterBy, @@ -2187,19 +1949,6 @@ private static boolean isRollupQuery(Query query) && !query.context().getBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true); } - private static boolean isInlineResults(final MSQSpec querySpec) - { - return querySpec.getDestination() instanceof TaskReportMSQDestination - || querySpec.getDestination() instanceof DurableStorageMSQDestination; - } - - private static boolean isTimeBucketedIngestion(final MSQSpec querySpec) - { - return MSQControllerTask.isIngestion(querySpec) - && !((DataSourceMSQDestination) querySpec.getDestination()).getSegmentGranularity() - .equals(Granularities.ALL); - } - /** * Compute shard columns for {@link DimensionRangeShardSpec}. Returns an empty list if range-based sharding * is not applicable. 
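[Editor's note] A recurring change in `makeQueryDefinition` is `QueryDefinition.builder()` becoming `QueryDefinition.builder(queryId)`, threading the controller's query ID through every generated stage rather than minting an ID inside the builder. Condensed from the hunks above (processor factory, signature, and shuffle setup elided), the pattern for appending a stage looks like:

// Append one more stage to an existing QueryDefinition, reusing its query ID.
final QueryDefinitionBuilder builder = QueryDefinition.builder(queryId);
builder.addAll(queryDef);
builder.add(
    StageDefinition.builder(queryDef.getNextStageNumber())
                   .inputs(new StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber()))
    // .processorFactory(...), .signature(...), and shuffle spec elided; see the hunks above
);
return builder.build();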
@@ -2477,7 +2226,8 @@ private static MSQStagesReport makeStageReport( final Map stagePhaseMap, final Map stageRuntimeMap, final Map stageWorkerCountMap, - final Map stagePartitionCountMap + final Map stagePartitionCountMap, + final Map stageOutputChannelModeMap ) { return MSQStagesReport.create( @@ -2485,35 +2235,8 @@ private static MSQStagesReport makeStageReport( ImmutableMap.copyOf(stagePhaseMap), copyOfStageRuntimesEndingAtCurrentTime(stageRuntimeMap), stageWorkerCountMap, - stagePartitionCountMap - ); - } - - private static MSQResultsReport makeResultsTaskReport( - final QueryDefinition queryDef, - final Yielder resultsYielder, - final ColumnMappings columnMappings, - @Nullable final List sqlTypeNames, - final MSQSelectDestination selectDestination - ) - { - final RowSignature querySignature = queryDef.getFinalStageDefinition().getSignature(); - final ImmutableList.Builder mappedSignature = ImmutableList.builder(); - - for (final ColumnMapping mapping : columnMappings.getMappings()) { - mappedSignature.add( - new MSQResultsReport.ColumnAndType( - mapping.getOutputColumn(), - querySignature.getColumnType(mapping.getQueryColumn()).orElse(null) - ) - ); - } - - return MSQResultsReport.createReportAndLimitRowsIfNeeded( - mappedSignature.build(), - sqlTypeNames, - resultsYielder, - selectDestination + stagePartitionCountMap, + stageOutputChannelModeMap ); } @@ -2523,17 +2246,17 @@ private static MSQStatusReport makeStatusReport( final Queue errorReports, @Nullable final DateTime queryStartTime, final long queryDuration, - MSQWorkerTaskLauncher taskLauncher, + final WorkerManager taskLauncher, final SegmentLoadStatusFetcher segmentLoadWaiter, @Nullable MSQSegmentReport msqSegmentReport ) { int pendingTasks = -1; int runningTasks = 1; - Map> workerStatsMap = new HashMap<>(); + Map> workerStatsMap = new HashMap<>(); if (taskLauncher != null) { - WorkerCount workerTaskCount = taskLauncher.getWorkerTaskCount(); + WorkerCount workerTaskCount = taskLauncher.getWorkerCount(); pendingTasks = workerTaskCount.getPendingWorkerCount(); runningTasks = workerTaskCount.getRunningWorkerCount() + 1; // To account for controller. workerStatsMap = taskLauncher.getWorkerStats(); @@ -2557,15 +2280,15 @@ private static MSQStatusReport makeStatusReport( ); } - private static InputSpecSlicerFactory makeInputSpecSlicerFactory(final DataSegmentTimelineView timelineView) + private static InputSpecSlicerFactory makeInputSpecSlicerFactory(final InputSpecSlicer tableInputSpecSlicer) { - return stagePartitionsMap -> new MapInputSpecSlicer( + return (stagePartitionsMap, stageOutputChannelModeMap) -> new MapInputSpecSlicer( ImmutableMap., InputSpecSlicer>builder() - .put(StageInputSpec.class, new StageInputSpecSlicer(stagePartitionsMap)) + .put(StageInputSpec.class, new StageInputSpecSlicer(stagePartitionsMap, stageOutputChannelModeMap)) .put(ExternalInputSpec.class, new ExternalInputSpecSlicer()) .put(InlineInputSpec.class, new InlineInputSpecSlicer()) .put(LookupInputSpec.class, new LookupInputSpecSlicer()) - .put(TableInputSpec.class, new TableInputSpecSlicer(timelineView)) + .put(TableInputSpec.class, tableInputSpecSlicer) .build() ); } @@ -2679,11 +2402,12 @@ private class RunQueryUntilDone { private final QueryDefinition queryDef; private final InputSpecSlicerFactory inputSpecSlicerFactory; + private final QueryListener queryListener; private final Closer closer; private final ControllerQueryKernel queryKernel; /** - * Return value of {@link MSQWorkerTaskLauncher#start()}. Set by {@link #startTaskLauncher()}. 
+ * Return value of {@link WorkerManager#start()}. Set by {@link #startTaskLauncher()}. */ private ListenableFuture workerTaskLauncherFuture; @@ -2694,20 +2418,26 @@ private class RunQueryUntilDone */ private List segmentsToGenerate; + /** + * Future that resolves when the reader from {@link #startQueryResultsReader()} finishes. Prior to that method + * being called, this future is null. + */ + @Nullable + private ListenableFuture queryResultsReaderFuture; + public RunQueryUntilDone( final QueryDefinition queryDef, + final ControllerQueryKernelConfig queryKernelConfig, final InputSpecSlicerFactory inputSpecSlicerFactory, + final QueryListener queryListener, final Closer closer ) { this.queryDef = queryDef; this.inputSpecSlicerFactory = inputSpecSlicerFactory; + this.queryListener = queryListener; this.closer = closer; - this.queryKernel = new ControllerQueryKernel( - queryDef, - workerMemoryParameters.getPartitionStatisticsMaxRetainedBytes(), - isFaultToleranceEnabled - ); + this.queryKernel = new ControllerQueryKernel(queryDef, queryKernelConfig); } /** @@ -2717,15 +2447,20 @@ private Pair> run() throws IOExceptio { startTaskLauncher(); + boolean runAgain; while (!queryKernel.isDone()) { startStages(); fetchStatsFromWorkers(); sendPartitionBoundaries(); updateLiveReportMaps(); - cleanUpEffectivelyFinishedStages(); + readQueryResults(); + runAgain = cleanUpEffectivelyFinishedStages(); retryFailedTasks(); checkForErrorsInSketchFetcher(); + + if (!runAgain) { + runKernelCommands(); + } } if (!queryKernel.isSuccess()) { @@ -2745,11 +2480,21 @@ private void checkForErrorsInSketchFetcher() } } + /** + * Read query results, if appropriate and possible. + */ + private void readQueryResults() + { + // Open query results channel, if appropriate.
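+ // Three conditions gate the reader: the listener must actually want result rows, the final stage's + // output must be available to read, and a reader must not already have been started.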
+ if (queryListener.readResults() && queryKernel.canReadQueryResults() && queryResultsReaderFuture == null) { + startQueryResultsReader(); + } + } private void retryFailedTasks() throws InterruptedException { // if there are no work orders to retry, skip - if (workOrdersToRetry.size() == 0) { + if (workOrdersToRetry.isEmpty()) { return; } Set workersNeedToBeFullyStarted = new HashSet<>(); @@ -2765,7 +2510,7 @@ private void retryFailedTasks() throws InterruptedException new StageId(queryDef.getQueryId(), workOrder.getStageNumber()), (stageId, workOrders) -> { if (workOrders == null) { - workOrders = new HashMap(); + workOrders = new HashMap<>(); } workOrders.put(workerStages.getKey(), workOrder); return workOrders; @@ -2775,27 +2520,23 @@ private void retryFailedTasks() throws InterruptedException } // wait till the workers identified above are fully ready - workerTaskLauncher.waitUntilWorkersReady(workersNeedToBeFullyStarted); + workerManager.waitForWorkers(workersNeedToBeFullyStarted); for (Map.Entry> stageWorkOrders : stageWorkerOrders.entrySet()) { - contactWorkersForStage( queryKernel, - (netClient, taskId, workerNumber) -> netClient.postWorkOrder( - taskId, - stageWorkOrders.getValue().get(workerNumber) - ), new IntArraySet(stageWorkOrders.getValue().keySet()), - (taskId) -> { - int workerNumber = MSQTasks.workerFromTaskId(taskId); + (netClient, workerId, workerNumber) -> + netClient.postWorkOrder(workerId, stageWorkOrders.getValue().get(workerNumber)), + (workerId, workerNumber) -> { queryKernel.workOrdersSentForWorker(stageWorkOrders.getKey(), workerNumber); // remove successfully contacted workOrders from workOrdersToRetry workOrdersToRetry.compute(workerNumber, (task, workOrderSet) -> { - if (workOrderSet == null || workOrderSet.size() == 0 || !workOrderSet.remove(stageWorkOrders.getValue() - .get( - workerNumber))) { - throw new ISE("Worker[%d] orders not found", workerNumber); + if (workOrderSet == null + || workOrderSet.size() == 0 + || !workOrderSet.remove(stageWorkOrders.getValue().get(workerNumber))) { + throw new ISE("Worker[%s] with number[%d] orders not found", workerId, workerNumber); } if (workOrderSet.size() == 0) { return null; @@ -2803,7 +2544,7 @@ private void retryFailedTasks() throws InterruptedException return workOrderSet; }); }, - isFaultToleranceEnabled + queryKernelConfig.isFaultTolerant() ); } } @@ -2827,16 +2568,16 @@ private void retryFailedTasks() throws InterruptedException } /** - * Start up the {@link MSQWorkerTaskLauncher}, such that later on it can be used to launch new tasks - * via {@link MSQWorkerTaskLauncher#launchTasksIfNeeded}. + * Start up the {@link WorkerManager}, such that later on it can be used to launch new tasks + * via {@link WorkerManager#launchWorkersIfNeeded}. */ private void startTaskLauncher() { // Start tasks.
log.debug("Query [%s] starting task launcher.", queryDef.getQueryId()); - workerTaskLauncherFuture = workerTaskLauncher.start(); - closer.register(() -> workerTaskLauncher.stop(true)); + workerTaskLauncherFuture = workerManager.start(); + closer.register(() -> workerManager.stop(true)); workerTaskLauncherFuture.addListener( () -> @@ -2857,7 +2598,7 @@ private void fetchStatsFromWorkers() for (Map.Entry> stageToWorker : queryKernel.getStagesAndWorkersToFetchClusterStats() .entrySet()) { - List allTasks = workerTaskLauncher.getActiveTasks(); + List allTasks = workerManager.getWorkerIds(); Set tasks = stageToWorker.getValue().stream().map(allTasks::get).collect(Collectors.toSet()); ClusterStatisticsMergeMode clusterStatisticsMergeMode = stageToStatsMergingMode.get(stageToWorker.getKey() @@ -2881,7 +2622,7 @@ private void submitParallelMergeRequests(StageId stageId, Set tasks) // eagerly change state of workers whose state is being fetched so that we do not keep on queuing fetch requests. queryKernel.startFetchingStatsFromWorker( stageId, - tasks.stream().map(MSQTasks::workerFromTaskId).collect(Collectors.toSet()) + tasks.stream().map(workerManager::getWorkerNumber).collect(Collectors.toSet()) ); workerSketchFetcher.inMemoryFullSketchMerging(ControllerImpl.this::addToKernelManipulationQueue, stageId, tasks, @@ -2896,13 +2637,14 @@ private void submitSequentialMergeFetchRequests(StageId stageId, Set tas queryKernel.startFetchingStatsFromWorker( stageId, tasks.stream() - .map(MSQTasks::workerFromTaskId) + .map(workerManager::getWorkerNumber) .collect(Collectors.toSet()) ); workerSketchFetcher.sequentialTimeChunkMerging( ControllerImpl.this::addToKernelManipulationQueue, queryKernel.getCompleteKeyStatisticsInformation(stageId), - stageId, tasks, + stageId, + tasks, ControllerImpl.this::addToRetryQueue ); } @@ -2914,69 +2656,88 @@ private void submitSequentialMergeFetchRequests(StageId stageId, Set tas private void startStages() throws IOException, InterruptedException { final long maxInputBytesPerWorker = - MultiStageQueryContext.getMaxInputBytesPerWorker(task.getQuerySpec().getQuery().context()); + MultiStageQueryContext.getMaxInputBytesPerWorker(querySpec.getQuery().context()); logKernelStatus(queryDef.getQueryId(), queryKernel); - final List newStageIds = queryKernel.createAndGetNewStageIds( - inputSpecSlicerFactory, - task.getQuerySpec().getAssignmentStrategy(), - maxInputBytesPerWorker - ); - for (final StageId stageId : newStageIds) { - - // Allocate segments, if this is the final stage of an ingestion. - if (MSQControllerTask.isIngestion(task.getQuerySpec()) - && stageId.getStageNumber() == queryDef.getFinalStageDefinition().getStageNumber()) { - // We need to find the shuffle details (like partition ranges) to generate segments. Generally this is - // going to correspond to the stage immediately prior to the final segment-generator stage. - int shuffleStageNumber = Iterables.getOnlyElement(queryDef.getFinalStageDefinition().getInputStageNumbers()); - - // The following logic assumes that output of all the stages without a shuffle retain the partition boundaries - // of the input to that stage. This may not always be the case. For example: GROUP BY queries without an - // ORDER BY clause. This works for QueryKit generated queries up until now, but it should be reworked as it - // might not always be the case. 
- while (!queryDef.getStageDefinition(shuffleStageNumber).doesShuffle()) { - shuffleStageNumber = - Iterables.getOnlyElement(queryDef.getStageDefinition(shuffleStageNumber).getInputStageNumbers()); - } + List newStageIds; + + do { + newStageIds = queryKernel.createAndGetNewStageIds( + inputSpecSlicerFactory, + querySpec.getAssignmentStrategy(), + maxInputBytesPerWorker + ); - final StageId shuffleStageId = new StageId(queryDef.getQueryId(), shuffleStageNumber); - final Boolean isShuffleStageOutputEmpty = queryKernel.isStageOutputEmpty(shuffleStageId); - if (isFailOnEmptyInsertEnabled && Boolean.TRUE.equals(isShuffleStageOutputEmpty)) { - throw new MSQException(new InsertCannotBeEmptyFault(task.getDataSource())); + for (final StageId stageId : newStageIds) { + // Allocate segments, if this is the final stage of an ingestion. + if (MSQControllerTask.isIngestion(querySpec) + && stageId.getStageNumber() == queryDef.getFinalStageDefinition().getStageNumber()) { + populateSegmentsToGenerate(); } - final ClusterByPartitions partitionBoundaries = - queryKernel.getResultPartitionBoundariesForStage(shuffleStageId); - - final boolean mayHaveMultiValuedClusterByFields = - !queryKernel.getStageDefinition(shuffleStageId).mustGatherResultKeyStatistics() - || queryKernel.hasStageCollectorEncounteredAnyMultiValueField(shuffleStageId); - - segmentsToGenerate = generateSegmentIdsWithShardSpecs( - (DataSourceMSQDestination) task.getQuerySpec().getDestination(), - queryKernel.getStageDefinition(shuffleStageId).getSignature(), - queryKernel.getStageDefinition(shuffleStageId).getClusterBy(), - partitionBoundaries, - mayHaveMultiValuedClusterByFields, - isShuffleStageOutputEmpty + + final int workerCount = queryKernel.getWorkerInputsForStage(stageId).workerCount(); + final StageDefinition stageDef = queryKernel.getStageDefinition(stageId); + log.info( + "Query [%s] using workers[%d] for stage[%d], writing to[%s], shuffle[%s].", + stageId.getQueryId(), + workerCount, + stageId.getStageNumber(), + queryKernel.getStageOutputChannelMode(stageId), + stageDef.doesShuffle() ? stageDef.getShuffleSpec().kind() : "none" ); - log.info("Query[%s] generating %d segments.", queryDef.getQueryId(), segmentsToGenerate.size()); + workerManager.launchWorkersIfNeeded(workerCount); + stageRuntimesForLiveReports.put(stageId.getStageNumber(), new Interval(DateTimes.nowUtc(), DateTimes.MAX)); + startWorkForStage(queryDef, queryKernel, stageId.getStageNumber(), segmentsToGenerate); } + } while (!newStageIds.isEmpty()); + } - final int workerCount = queryKernel.getWorkerInputsForStage(stageId).workerCount(); - log.info( - "Query [%s] starting %d workers for stage %d.", - stageId.getQueryId(), - workerCount, - stageId.getStageNumber() - ); + /** + * Populate {@link #segmentsToGenerate} for ingestion. + */ + private void populateSegmentsToGenerate() throws IOException + { + // We need to find the shuffle details (like partition ranges) to generate segments. Generally this is + // going to correspond to the stage immediately prior to the final segment-generator stage. + int shuffleStageNumber = Iterables.getOnlyElement(queryDef.getFinalStageDefinition().getInputStageNumbers()); + + // The following logic assumes that output of all the stages without a shuffle retain the partition boundaries + // of the input to that stage. This may not always be the case. For example: GROUP BY queries without an + // ORDER BY clause. This works for QueryKit generated queries up until now, but it should be reworked as it + // might not always be the case. 
+ while (!queryDef.getStageDefinition(shuffleStageNumber).doesShuffle()) { + shuffleStageNumber = + Iterables.getOnlyElement(queryDef.getStageDefinition(shuffleStageNumber).getInputStageNumbers()); + } - final StageId shuffleStageId = new StageId(queryDef.getQueryId(), shuffleStageNumber); - final Boolean isShuffleStageOutputEmpty = queryKernel.isStageOutputEmpty(shuffleStageId); - if (isFailOnEmptyInsertEnabled && Boolean.TRUE.equals(isShuffleStageOutputEmpty)) { - throw new MSQException(new InsertCannotBeEmptyFault(task.getDataSource())); + final StageId shuffleStageId = new StageId(queryDef.getQueryId(), shuffleStageNumber); + + final boolean isFailOnEmptyInsertEnabled = + MultiStageQueryContext.isFailOnEmptyInsertEnabled(querySpec.getQuery().context()); + final Boolean isShuffleStageOutputEmpty = queryKernel.isStageOutputEmpty(shuffleStageId); + if (isFailOnEmptyInsertEnabled && Boolean.TRUE.equals(isShuffleStageOutputEmpty)) { + throw new MSQException(new InsertCannotBeEmptyFault(getDataSourceForIngestion(querySpec))); } + + final ClusterByPartitions partitionBoundaries = + queryKernel.getResultPartitionBoundariesForStage(shuffleStageId); + + final boolean mayHaveMultiValuedClusterByFields = + !queryKernel.getStageDefinition(shuffleStageId).mustGatherResultKeyStatistics() + || queryKernel.hasStageCollectorEncounteredAnyMultiValueField(shuffleStageId); + + segmentsToGenerate = generateSegmentIdsWithShardSpecs( + (DataSourceMSQDestination) querySpec.getDestination(), + queryKernel.getStageDefinition(shuffleStageId).getSignature(), + queryKernel.getStageDefinition(shuffleStageId).getClusterBy(), + partitionBoundaries, + mayHaveMultiValuedClusterByFields, + isShuffleStageOutputEmpty + ); + + log.info("Query [%s] generating %d segments.", queryDef.getQueryId(), partitionBoundaries.size()); } /** @@ -3033,7 +2794,7 @@ private void updateLiveReportMaps() { logKernelStatus(queryDef.getQueryId(), queryKernel); - // Live reports: update stage phases, worker counts, partition counts. + // Live reports: update stage phases, worker counts, partition counts, output channel modes. for (StageId stageId : queryKernel.getActiveStages()) { final int stageNumber = stageId.getStageNumber(); stagePhasesForLiveReports.put(stageNumber, queryKernel.getStagePhase(stageId)); @@ -3045,15 +2806,20 @@ private void updateLiveReportMaps() ); } - stageWorkerCountsForLiveReports.putIfAbsent( + stageWorkerCountsForLiveReports.computeIfAbsent( stageNumber, - queryKernel.getWorkerInputsForStage(stageId).workerCount() + k -> queryKernel.getWorkerInputsForStage(stageId).workerCount() + ); + + stageOutputChannelModesForLiveReports.computeIfAbsent( + stageNumber, + k -> queryKernel.getStageOutputChannelMode(stageId) ); } // Live reports: update stage end times for any stages that just ended. for (StageId stageId : queryKernel.getActiveStages()) { - if (ControllerStagePhase.isSuccessfulTerminalPhase(queryKernel.getStagePhase(stageId))) { + if (queryKernel.getStagePhase(stageId).isSuccess()) { stageRuntimesForLiveReports.compute( queryKernel.getStageDefinition(stageId).getStageNumber(), (k, currentValue) -> { @@ -3070,21 +2836,144 @@ private void updateLiveReportMaps() /** * Issue cleanup commands to any stages that are effectively finished, allowing them to delete their outputs.
+ * + * @return true if any stages were cleaned up */ - private void cleanUpEffectivelyFinishedStages() + private boolean cleanUpEffectivelyFinishedStages() { + final StageId finalStageId = queryDef.getFinalStageDefinition().getId(); + boolean didSomething = false; for (final StageId stageId : queryKernel.getEffectivelyFinishedStageIds()) { + if (finalStageId.equals(stageId) + && queryListener.readResults() + && (queryResultsReaderFuture == null || !queryResultsReaderFuture.isDone())) { + // Don't clean up final stage until results are done being read. + continue; + } + log.info("Query [%s] issuing cleanup order for stage %d.", queryDef.getQueryId(), stageId.getStageNumber()); contactWorkersForStage( queryKernel, - (netClient, taskId, workerNumber) -> netClient.postCleanupStage(taskId, stageId), queryKernel.getWorkerInputsForStage(stageId).workers(), - (ignore1) -> { - }, + (netClient, workerId, workerNumber) -> netClient.postCleanupStage(workerId, stageId), + (workerId, workerNumber) -> {}, false ); queryKernel.finishStage(stageId, true); + didSomething = true; + } + return didSomething; + } + + /** + * Start a {@link ControllerQueryResultsReader} that pushes results to our {@link QueryListener}. + * + * The reader runs in a single-threaded executor that is created by this method, and shut down when results + * are done being read. + */ + private void startQueryResultsReader() + { + if (queryResultsReaderFuture != null) { + throw new ISE("Already started"); + } + + final StageId finalStageId = queryKernel.getStageId(queryDef.getFinalStageDefinition().getStageNumber()); + final List taskIds = getTaskIds(); + + final InputChannelFactory inputChannelFactory; + + if (queryKernelConfig.isDurableStorage() || MSQControllerTask.writeResultsToDurableStorage(querySpec)) { + inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation( + queryId(), + MSQTasks.makeStorageConnector(context.injector()), + closer, + MSQControllerTask.writeResultsToDurableStorage(querySpec) + ); + } else { + inputChannelFactory = new WorkerInputChannelFactory(netClient, () -> taskIds); + } + + final FrameProcessorExecutor resultReaderExec = new FrameProcessorExecutor( + MoreExecutors.listeningDecorator( + Execs.singleThreaded(StringUtils.encodeForFormat("msq-result-reader[" + queryId() + "]"))) + ); + + final String cancellationId = "results-reader"; + ReadableConcatFrameChannel resultsChannel = null; + + try { + final InputChannels inputChannels = new InputChannelsImpl( + queryDef, + queryKernel.getResultPartitionsForStage(finalStageId), + inputChannelFactory, + () -> ArenaMemoryAllocator.createOnHeap(5_000_000), + resultReaderExec, + cancellationId + ); + + resultsChannel = ReadableConcatFrameChannel.open( + StreamSupport.stream(queryKernel.getResultPartitionsForStage(finalStageId).spliterator(), false) + .map( + readablePartition -> { + try { + return inputChannels.openChannel( + new StagePartition( + queryKernel.getStageDefinition(finalStageId).getId(), + readablePartition.getPartitionNumber() + ) + ); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + ) + .iterator() + ); + + final ControllerQueryResultsReader resultsReader = new ControllerQueryResultsReader( + resultsChannel, + queryDef.getFinalStageDefinition().getFrameReader(), + querySpec.getColumnMappings(), + resultsContext, + context.jsonMapper(), + queryListener + ); + + queryResultsReaderFuture = resultReaderExec.runFully(resultsReader, cancellationId); + + // When results are done being read, kick the main 
thread. + // Important: don't use FutureUtils.futureWithBaggage, because we need queryResultsReaderFuture to resolve + // *before* the main thread is kicked. + queryResultsReaderFuture.addListener( + () -> addToKernelManipulationQueue(holder -> {}), + Execs.directExecutor() + ); } + catch (Throwable e) { + // There was some issue setting up the result reader. Shut down the results channel and stop the executor. + final ReadableConcatFrameChannel finalResultsChannel = resultsChannel; + throw CloseableUtils.closeAndWrapInCatch( + e, + () -> CloseableUtils.closeAll( + finalResultsChannel, + () -> resultReaderExec.getExecutorService().shutdownNow() + ) + ); + } + + // Result reader is set up. Register with the query-wide closer. + closer.register(() -> { + try { + resultReaderExec.cancel(cancellationId); + } + catch (Exception e) { + throw new RuntimeException(e); + } + finally { + resultReaderExec.getExecutorService().shutdownNow(); + } + }); } /** @@ -3158,7 +3047,7 @@ private MSQErrorReport mapQueryColumnNameToOutputColumnName( .value(inbf.getValue()) .position(inbf.getPosition()) .build(), - task.getQuerySpec().getColumnMappings() + querySpec.getColumnMappings() ); } else if (workerErrorReport.getFault() instanceof InvalidFieldFault) { InvalidFieldFault iff = (InvalidFieldFault) workerErrorReport.getFault(); @@ -3172,7 +3061,7 @@ private MSQErrorReport mapQueryColumnNameToOutputColumnName( .column(iff.getColumn()) .errorMsg(iff.getErrorMsg()) .build(), - task.getQuerySpec().getColumnMappings() + querySpec.getColumnMappings() ); } else { return workerErrorReport; @@ -3185,7 +3074,7 @@ private MSQErrorReport mapQueryColumnNameToOutputColumnName( */ private interface TaskContactFn { - ListenableFuture contactTask(WorkerClient client, String taskId, int workerNumber); + ListenableFuture contactTask(WorkerClient client, String workerId, int workerNumber); } /** @@ -3193,7 +3082,6 @@ private interface TaskContactFn */ private interface TaskContactSuccess { - void onSuccess(String taskId); - + void onSuccess(String workerId, int workerNumber); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java new file mode 100644 index 000000000000..8e6fc72b6aa7 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.exec; + +import com.google.common.base.Preconditions; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.NotEnoughMemoryFault; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernel; +import org.apache.druid.msq.statistics.ClusterByStatisticsCollectorImpl; + +/** + * Class for determining how much JVM heap to allocate to various purposes for {@link Controller}. + * + * First, look at how much of total JVM heap that is dedicated for MSQ; see + * {@link MemoryIntrospector#usableMemoryInJvm()}. + * + * Then, we split up that total amount of memory into equally-sized portions per {@link Controller}; see + * {@link MemoryIntrospector#numQueriesInJvm()}. The number of controllers is based entirely on server configuration, + * which makes the calculation robust to different queries running simultaneously in the same JVM. + * + * Then, we split that up into a chunk used for input channels, and a chunk used for partition statistics. + */ +public class ControllerMemoryParameters +{ + /** + * Maximum number of bytes that we'll ever use for maxRetainedBytes of {@link ClusterByStatisticsCollectorImpl}. + */ + private static final long PARTITION_STATS_MAX_MEMORY = 300_000_000; + + /** + * Minimum number of bytes that is allowable for maxRetainedBytes of {@link ClusterByStatisticsCollectorImpl}. + */ + private static final long PARTITION_STATS_MIN_MEMORY = 25_000_000; + + /** + * Memory allocated to {@link ClusterByStatisticsCollectorImpl} as part of {@link ControllerQueryKernel}. + */ + private final int partitionStatisticsMaxRetainedBytes; + + public ControllerMemoryParameters(int partitionStatisticsMaxRetainedBytes) + { + this.partitionStatisticsMaxRetainedBytes = partitionStatisticsMaxRetainedBytes; + } + + /** + * Create an instance. 
+ * + * @param memoryIntrospector memory introspector + * @param maxWorkerCount maximum worker count of the final stage + */ + public static ControllerMemoryParameters createProductionInstance( + final MemoryIntrospector memoryIntrospector, + final int maxWorkerCount + ) + { + final long usableMemoryInJvm = memoryIntrospector.usableMemoryInJvm(); + final int numControllersInJvm = memoryIntrospector.numQueriesInJvm(); + Preconditions.checkArgument(usableMemoryInJvm > 0, "Usable memory[%s] must be > 0", usableMemoryInJvm); + Preconditions.checkArgument(numControllersInJvm > 0, "Number of controllers[%s] must be > 0", numControllersInJvm); + Preconditions.checkArgument(maxWorkerCount > 0, "Number of workers[%s] must be > 0", maxWorkerCount); + + final long memoryPerController = usableMemoryInJvm / numControllersInJvm; + final long memoryForInputChannels = WorkerMemoryParameters.memoryNeededForInputChannels(maxWorkerCount); + final int partitionStatisticsMaxRetainedBytes = (int) Math.min( + memoryPerController - memoryForInputChannels, + PARTITION_STATS_MAX_MEMORY + ); + + if (partitionStatisticsMaxRetainedBytes < PARTITION_STATS_MIN_MEMORY) { + final long requiredMemory = memoryForInputChannels + PARTITION_STATS_MIN_MEMORY; + throw new MSQException( + new NotEnoughMemoryFault( + memoryIntrospector.computeJvmMemoryRequiredForUsableMemory(requiredMemory), + memoryIntrospector.totalMemoryInJvm(), + usableMemoryInJvm, + numControllersInJvm, + memoryIntrospector.numProcessorsInJvm() + ) + ); + } + + return new ControllerMemoryParameters(partitionStatisticsMaxRetainedBytes); + } + + /** + * Memory allocated to {@link ClusterByStatisticsCollectorImpl} as part of {@link ControllerQueryKernel}. + */ + public int getPartitionStatisticsMaxRetainedBytes() + { + return partitionStatisticsMaxRetainedBytes; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerQueryResultsReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerQueryResultsReader.java new file mode 100644 index 000000000000..ae24704c9d8e --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerQueryResultsReader.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.exec; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.WritableFrameChannel; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.processor.ReturnOrAwait; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.java.util.common.guava.Yielder; +import org.apache.druid.java.util.common.guava.Yielders; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.util.SqlStatementResourceHelper; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.sql.calcite.planner.ColumnMapping; +import org.apache.druid.sql.calcite.planner.ColumnMappings; +import org.apache.druid.utils.CloseableUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +/** + * Used by {@link ControllerImpl} to read query results and hand them to a {@link QueryListener}. + */ +public class ControllerQueryResultsReader implements FrameProcessor +{ + private static final Logger log = new Logger(ControllerQueryResultsReader.class); + + private final ReadableFrameChannel in; + private final FrameReader frameReader; + private final ColumnMappings columnMappings; + private final ResultsContext resultsContext; + private final ObjectMapper jsonMapper; + private final QueryListener queryListener; + + private boolean wroteResultsStart; + + ControllerQueryResultsReader( + final ReadableFrameChannel in, + final FrameReader frameReader, + final ColumnMappings columnMappings, + final ResultsContext resultsContext, + final ObjectMapper jsonMapper, + final QueryListener queryListener + ) + { + this.in = in; + this.frameReader = frameReader; + this.columnMappings = columnMappings; + this.resultsContext = resultsContext; + this.jsonMapper = jsonMapper; + this.queryListener = queryListener; + } + + @Override + public List inputChannels() + { + return Collections.singletonList(in); + } + + @Override + public List outputChannels() + { + return Collections.emptyList(); + } + + @Override + public ReturnOrAwait runIncrementally(final IntSet readableInputs) + { + if (readableInputs.isEmpty()) { + return ReturnOrAwait.awaitAll(inputChannels().size()); + } + + if (!wroteResultsStart) { + final RowSignature querySignature = frameReader.signature(); + final ImmutableList.Builder mappedSignature = ImmutableList.builder(); + + for (final ColumnMapping mapping : columnMappings.getMappings()) { + mappedSignature.add( + new MSQResultsReport.ColumnAndType( + mapping.getOutputColumn(), + querySignature.getColumnType(mapping.getQueryColumn()).orElse(null) + ) + ); + } + + queryListener.onResultsStart( + mappedSignature.build(), + resultsContext.getSqlTypeNames() + ); + + wroteResultsStart = true; + } + + // Read from query results channel, if it's open. 
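+ // Note: runIncrementally is re-invoked until the channel reports finished or the listener declines more + // rows. Each pass unpacks one frame into rows for the listener; on the final pass, onResultsComplete() is + // called before returning.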
+ if (in.isFinished()) { + queryListener.onResultsComplete(); + return ReturnOrAwait.returnObject(null); + } else { + final Frame frame = in.read(); + Yielder rowYielder = Yielders.each( + SqlStatementResourceHelper.getResultSequence( + frame, + frameReader, + columnMappings, + resultsContext, + jsonMapper + ) + ); + + try { + while (!rowYielder.isDone()) { + if (queryListener.onResultRow(rowYielder.get())) { + rowYielder = rowYielder.next(null); + } else { + // Caller wanted to stop reading. + return ReturnOrAwait.returnObject(null); + } + } + } + finally { + CloseableUtils.closeAndSuppressExceptions(rowYielder, e -> log.warn(e, "Failed to close frame yielder")); + } + + return ReturnOrAwait.awaitAll(inputChannels().size()); + } + } + + @Override + public void cleanup() throws IOException + { + FrameProcessors.closeAll(inputChannels(), outputChannels()); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java index 93dbc0080045..8a7607d3159a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java @@ -60,25 +60,23 @@ public ListenableFuture postWorkOrder(String workerTaskId, WorkOrder workO @Override public ListenableFuture fetchClusterByStatisticsSnapshot( String workerTaskId, - String queryId, - int stageNumber + StageId stageId ) { - return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, queryId, stageNumber)); + return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, stageId)); } @Override public ListenableFuture fetchClusterByStatisticsSnapshotForTimeChunk( String workerTaskId, - String queryId, - int stageNumber, + StageId stageId, long timeChunk ) { return wrap( workerTaskId, client, - c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, queryId, stageNumber, timeChunk) + c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, stageId, timeChunk) ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java index 8b6f26770a5d..bb782cb67d9a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java @@ -86,9 +86,10 @@ public class Limits public static final long MAX_WORKERS_FOR_PARALLEL_MERGE = 100; /** - * Max number of rows in the query reports of the SELECT queries run by MSQ. This ensures that the reports donot blow - * up for queries operating on larger datasets. The full result of the select query should be available once the - * MSQ is able to run async queries + * Max number of rows in the query reports of SELECT queries run by MSQ when using + * {@link org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination}. Reports in this mode contain a + * preview of actual query results, but not the full result set. This ensures that the reports do not blow up in + * size for queries operating on larger datasets.
*/ public static final long MAX_SELECT_RESULT_ROWS = 3_000; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospector.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospector.java new file mode 100644 index 000000000000..337e36d14efa --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospector.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import org.apache.druid.msq.kernel.WorkOrder; + +/** + * Introspector used to generate {@link ControllerMemoryParameters}. + */ +public interface MemoryIntrospector +{ + /** + * Amount of total memory in the entire JVM. + */ + long totalMemoryInJvm(); + + /** + * Amount of memory usable for the multi-stage query engine in the entire JVM. + * + * This may be an expensive operation. For example, the production implementation {@link MemoryIntrospectorImpl} + * estimates size of all lookups as part of computing this value. + */ + long usableMemoryInJvm(); + + /** + * Amount of total JVM memory required for a particular amount of usable memory to be available. + * + * This may be an expensive operation. For example, the production implementation {@link MemoryIntrospectorImpl} + * estimates size of all lookups as part of computing this value. + */ + long computeJvmMemoryRequiredForUsableMemory(long usableMemory); + + /** + * Maximum number of queries that run simultaneously in this JVM. + * + * On workers, this is the maximum number of {@link Worker} that run simultaneously in this JVM. See + * {@link WorkerMemoryParameters} for how memory is divided among and within {@link WorkOrder} run by a worker. + * + * On controllers, this is the maximum number of {@link Controller} that run simultaneously. See + * {@link ControllerMemoryParameters} for how memory is used by controllers. + */ + int numQueriesInJvm(); + + /** + * Maximum number of processing threads that can be used at once in this JVM. + */ + int numProcessorsInJvm(); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospectorImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospectorImpl.java new file mode 100644 index 000000000000..f7cd501ed8fd --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospectorImpl.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.query.lookup.LookupExtractor; +import org.apache.druid.query.lookup.LookupExtractorFactoryContainer; +import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; + +import java.util.List; + +/** + * Production implementation of {@link MemoryIntrospector}. + */ +public class MemoryIntrospectorImpl implements MemoryIntrospector +{ + private static final Logger log = new Logger(MemoryIntrospectorImpl.class); + + private final LookupExtractorFactoryContainerProvider lookupProvider; + private final long totalMemoryInJvm; + private final int numQueriesInJvm; + private final int numProcessorsInJvm; + private final double usableMemoryFraction; + + /** + * Create an introspector. + * + * @param lookupProvider provider of lookups; we use this to subtract lookup size from total JVM memory when + * computing usable memory + * @param totalMemoryInJvm maximum JVM heap memory + * @param usableMemoryFraction fraction of JVM memory, after subtracting lookup overhead, that we consider usable + * for multi-stage queries + * @param numQueriesInJvm maximum number of {@link Controller} or {@link Worker} that may run concurrently + * @param numProcessorsInJvm size of processing thread pool, typically {@link DruidProcessingConfig#getNumThreads()} + */ + public MemoryIntrospectorImpl( + final LookupExtractorFactoryContainerProvider lookupProvider, + final long totalMemoryInJvm, + final double usableMemoryFraction, + final int numQueriesInJvm, + final int numProcessorsInJvm + ) + { + this.lookupProvider = lookupProvider; + this.totalMemoryInJvm = totalMemoryInJvm; + this.numQueriesInJvm = numQueriesInJvm; + this.numProcessorsInJvm = numProcessorsInJvm; + this.usableMemoryFraction = usableMemoryFraction; + } + + @Override + public long totalMemoryInJvm() + { + return totalMemoryInJvm; + } + + @Override + public long usableMemoryInJvm() + { + final long totalMemory = totalMemoryInJvm(); + final long totalLookupFootprint = computeTotalLookupFootprint(true); + return Math.max( + 0, + (long) ((totalMemory - totalLookupFootprint) * usableMemoryFraction) + ); + } + + @Override + public long computeJvmMemoryRequiredForUsableMemory(long usableMemory) + { + final long totalLookupFootprint = computeTotalLookupFootprint(false); + return (long) Math.ceil(usableMemory / usableMemoryFraction + totalLookupFootprint); + } + + @Override + public int numQueriesInJvm() + { + return numQueriesInJvm; + } + + @Override + public int numProcessorsInJvm() + { + return numProcessorsInJvm; + } + + /** + * Compute and return total estimated lookup footprint. + * + * Correctness of this approach depends on lookups being loaded *before* calling this method. 
Luckily, this is the + * typical mode of operation, since by default druid.lookup.enableLookupSyncOnStartup = true. + * + * @param logFootprint whether footprint should be logged + */ + private long computeTotalLookupFootprint(final boolean logFootprint) + { + final List lookupNames = ImmutableList.copyOf(lookupProvider.getAllLookupNames()); + + long lookupFootprint = 0; + + for (final String lookupName : lookupNames) { + final LookupExtractorFactoryContainer container = lookupProvider.get(lookupName).orElse(null); + + if (container != null) { + try { + final LookupExtractor extractor = container.getLookupExtractorFactory().get(); + lookupFootprint += extractor.estimateHeapFootprint(); + } + catch (Exception e) { + log.noStackTrace().warn(e, "Failed to load lookup[%s] for size estimation. Skipping.", lookupName); + } + } + } + + if (logFootprint) { + log.info("Lookup footprint: lookup count[%d], total bytes[%,d].", lookupNames.size(), lookupFootprint); + } + + return lookupFootprint; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java new file mode 100644 index 000000000000..7e7fc3d3d6f3 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelUtils; + +/** + * Mode for stage output channels. Provided to workers in {@link WorkOrder#getOutputChannelMode()}. + */ +public enum OutputChannelMode +{ + /** + * In-memory output channels. Stage shuffle data does not hit disk. This mode requires a consumer stage to run + * at the same time as its corresponding producer stage. See {@link ControllerQueryKernelUtils#computeStageGroups} for the + * logic that determines when we can use in-memory channels. + */ + MEMORY("memory"), + + /** + * Local file output channels. Stage shuffle data is stored in files on disk on the producer, and served via HTTP + * to the consumer. + */ + LOCAL_STORAGE("localStorage"), + + /** + * Durable storage output channels. Stage shuffle data is written by producers to durable storage (e.g. cloud + * storage), and is read from durable storage by consumers. 
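+ *
+ * Like every mode, this one serializes to its string name. A hypothetical round trip, using only the serde
+ * methods declared below:
+ * <pre>{@code
+ * OutputChannelMode mode = OutputChannelMode.fromString("durableStorage");
+ * mode.isDurable();  // true
+ * mode.toString();   // "durableStorage", the value used in JSON
+ * }</pre>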
+ */ + DURABLE_STORAGE_INTERMEDIATE("durableStorage"), + + /** + * Like {@link #DURABLE_STORAGE_INTERMEDIATE}, but a special case for the final stage + * {@link QueryDefinition#getFinalStageDefinition()}. The structure of files in deep storage is somewhat different. + */ + DURABLE_STORAGE_QUERY_RESULTS("durableStorageQueryResults"); + + private final String name; + + OutputChannelMode(String name) + { + this.name = name; + } + + @JsonCreator + public static OutputChannelMode fromString(final String s) + { + for (final OutputChannelMode mode : values()) { + if (mode.toString().equals(s)) { + return mode; + } + } + + throw new IAE("No such outputChannelMode[%s]", s); + } + + /** + * Whether this mode involves writing to durable storage. + */ + public boolean isDurable() + { + return this == DURABLE_STORAGE_INTERMEDIATE || this == DURABLE_STORAGE_QUERY_RESULTS; + } + + @Override + @JsonValue + public String toString() + { + return name; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/QueryListener.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/QueryListener.java new file mode 100644 index 000000000000..997fe4c8682d --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/QueryListener.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; + +import javax.annotation.Nullable; +import java.util.List; + +/** + * Object passed to {@link Controller#run(QueryListener)} to enable retrieval of results, status, counters, etc. + */ +public interface QueryListener +{ + /** + * Whether this listener is meant to receive results. + */ + boolean readResults(); + + /** + * Called when results start coming in. + * + * @param signature signature of results + * @param sqlTypeNames SQL type names of results; same length as the signature + */ + void onResultsStart( + List signature, + @Nullable List sqlTypeNames + ); + + /** + * Called for each result row. Follows a call to {@link #onResultsStart(List, List)}. + * + * @param row result row + * + * @return whether the controller should keep reading results + */ + boolean onResultRow(Object[] row); + + /** + * Called after the last result has been delivered by {@link #onResultRow(Object[])}. Only called if results are + * actually complete. If results are truncated due to {@link #readResults()} or {@link #onResultRow(Object[])} + * returning false, this method is not called. 
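+ *
+ * To illustrate where this call falls in the overall sequence, a sketch of how a controller might drive a
+ * listener (conceptual pseudocode; only the QueryListener methods shown are real):
+ * <pre>{@code
+ * if (listener.readResults()) {
+ *   listener.onResultsStart(signature, sqlTypeNames);
+ *   boolean keepReading = true;
+ *   while (keepReading && rows.hasNext()) {
+ *     keepReading = listener.onResultRow(rows.next());
+ *   }
+ *   if (keepReading) {
+ *     listener.onResultsComplete();  // only if all rows were delivered
+ *   }
+ * }
+ * listener.onQueryComplete(report);  // always the final call
+ * }</pre>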
+ */ + void onResultsComplete(); + + /** + * Called when the query is complete and a report is available. After this method is called, no other methods + * will be called. The report will not include a {@link MSQResultsReport}. + */ + void onQueryComplete(MSQTaskReportPayload report); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ResultsContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ResultsContext.java new file mode 100644 index 000000000000..9e565bb75a5d --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ResultsContext.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.sql.calcite.run.SqlResults; + +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +/** + * Holder for objects needed to interpret SQL results. + */ +public class ResultsContext +{ + private final List sqlTypeNames; + private final SqlResults.Context sqlResultsContext; + + public ResultsContext( + final List sqlTypeNames, + final SqlResults.Context sqlResultsContext + ) + { + this.sqlTypeNames = sqlTypeNames; + this.sqlResultsContext = sqlResultsContext; + } + + @Nullable + public List getSqlTypeNames() + { + return sqlTypeNames; + } + + @Nullable + public SqlResults.Context getSqlResultsContext() + { + return sqlResultsContext; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ResultsContext that = (ResultsContext) o; + return Objects.equals(sqlTypeNames, that.sqlTypeNames) + && Objects.equals(sqlResultsContext, that.sqlResultsContext); + } + + @Override + public int hashCode() + { + return Objects.hash(sqlTypeNames, sqlResultsContext); + } + + @Override + public String toString() + { + return "ResultsContext{" + + "sqlTypeNames=" + sqlTypeNames + + ", sqlResultsContext=" + sqlResultsContext + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RetryCapableWorkerManager.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RetryCapableWorkerManager.java new file mode 100644 index 000000000000..d5b3d41d7a2b --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RetryCapableWorkerManager.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+/**
+ * Expanded {@link WorkerManager} interface with methods to support retrying workers.
+ */
+public interface RetryCapableWorkerManager extends WorkerManager
+{
+  /**
+   * Queues a worker for relaunch. A no-op if the worker is already in the queue.
+   */
+  void submitForRelaunch(int workerNumber);
+
+  /**
+   * Reports a worker that failed while it had no active work orders. Such a worker is relaunched only if it is
+   * required for future stages.
+   */
+  void reportFailedInactiveWorker(int workerNumber);
+
+  /**
+   * Checks whether the given taskId was canceled by the controller. Used by {@link ControllerImpl} to ignore
+   * error reports from worker tasks that the controller itself canceled.
+   *
+   * @return true if the task was canceled by the controller, else false
+   */
+  boolean isTaskCanceledByController(String taskId);
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java
index 1546766f856f..d4eaef600125 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java
@@ -104,7 +104,7 @@ public class SegmentLoadStatusFetcher implements AutoCloseable
   public SegmentLoadStatusFetcher(
       BrokerClient brokerClient,
       ObjectMapper objectMapper,
-      String taskId,
+      String queryId,
       String datasource,
       Set<DataSegment> dataSegments,
       boolean doWait
@@ -128,7 +128,9 @@ public SegmentLoadStatusFetcher(
         totalSegmentsGenerated
     ));
     this.doWait = doWait;
-    this.executorService = MoreExecutors.listeningDecorator(Execs.singleThreaded(taskId + "-segment-load-waiter-%d"));
+    this.executorService = MoreExecutors.listeningDecorator(
+        Execs.singleThreaded(StringUtils.encodeForFormat(queryId) + "-segment-load-waiter-%d")
+    );
   }

   /**
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java
index 5c02a79f89a3..572051124a74 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java
@@ -27,26 +27,27 @@
 import org.apache.druid.msq.kernel.WorkOrder;
 import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;

+import java.io.Closeable;
 import java.io.IOException;

 /**
- * Client for multi-stage query workers. Used by the controller task.
+ * Client for {@link Worker}. Each instance is scoped to a single query, and can communicate with all workers for
+ * that particular query.
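+ *
+ * As a usage sketch (hypothetical; assumes a {@code workerIds} list obtained from a {@code WorkerManager}),
+ * a controller can fan out a counters fetch to every worker for the query:
+ * <pre>{@code
+ * final List<ListenableFuture<CounterSnapshotsTree>> futures = new ArrayList<>();
+ * for (String workerId : workerIds) {
+ *   futures.add(workerClient.getCounters(workerId));
+ * }
+ * }</pre>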
 */
-public interface WorkerClient extends AutoCloseable
+public interface WorkerClient extends Closeable
 {
   /**
    * Worker's client method to add a {@link WorkOrder} to the worker to work on.
    */
-  ListenableFuture<Void> postWorkOrder(String workerTaskId, WorkOrder workOrder);
+  ListenableFuture<Void> postWorkOrder(String workerId, WorkOrder workOrder);

   /**
    * Fetches the {@link ClusterByStatisticsSnapshot} from a worker. This is intended to be used by the
    * {@link WorkerSketchFetcher} under PARALLEL or AUTO modes.
    */
   ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
-      String workerTaskId,
-      String queryId,
-      int stageNumber
+      String workerId,
+      StageId stageId
   );

   /**
@@ -54,9 +55,8 @@ ListenableFuture fetchClusterByStatisticsSnapshot(
    * This is intended to be used by the {@link WorkerSketchFetcher} under SEQUENTIAL or AUTO modes.
    */
   ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
-      String workerTaskId,
-      String queryId,
-      int stageNumber,
+      String workerId,
+      StageId stageId,
       long timeChunk
   );

   /**
@@ -65,28 +65,26 @@ ListenableFuture fetchClusterByStatisticsSnapshotFo
    * controller after collating the result statistics from all the workers processing the query.
    */
   ListenableFuture<Void> postResultPartitionBoundaries(
-      String workerTaskId,
+      String workerId,
       StageId stageId,
       ClusterByPartitions partitionBoundaries
   );

   /**
-   * Worker's client method to inform that the work has been done, and it can initiate cleanup and shutdown
-   * @param workerTaskId
+   * Fetches counters from a worker.
    */
-  ListenableFuture<Void> postFinish(String workerTaskId);
+  ListenableFuture<CounterSnapshotsTree> getCounters(String workerId);

   /**
-   * Fetches all the counters gathered by that worker
-   * @param workerTaskId
+   * Worker's client method that informs it that the results and resources for the given stage are no longer
+   * required, and that they can be cleaned up.
    */
-  ListenableFuture<CounterSnapshotsTree> getCounters(String workerTaskId);
+  ListenableFuture<Void> postCleanupStage(String workerId, StageId stageId);

   /**
-   * Worker's client method that informs it that the results and resources for the given stage are no longer required
-   * and that they can be cleaned up
+   * Worker's client method to inform it that the work has been done, and that it can initiate cleanup and shutdown.
    */
-  ListenableFuture<Void> postCleanupStage(String workerTaskId, StageId stageId);
+  ListenableFuture<Void> postFinish(String workerId);

   /**
    * Fetch some data from a worker and add it to the provided channel. The exact amount of data is determined
@@ -96,13 +94,16 @@ ListenableFuture postResultPartitionBoundaries(
    * kind of unrecoverable exception).
    */
   ListenableFuture<Boolean> fetchChannelData(
-      String workerTaskId,
+      String workerId,
       StageId stageId,
       int partitionNumber,
       long offset,
       ReadableByteChunksFrameChannel channel
   );

+  /**
+   * Close this client and release resources.
+   */
   @Override
   void close() throws IOException;
 }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/RetryTask.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerFailureListener.java
similarity index 72%
rename from extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/RetryTask.java
rename to extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerFailureListener.java
index 39fb1e688ecf..9bc4ed56cde7 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/RetryTask.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerFailureListener.java
@@ -17,17 +17,18 @@
  * under the License.
  */

-package org.apache.druid.msq.indexing;
+package org.apache.druid.msq.exec;

+import org.apache.druid.msq.indexing.MSQWorkerTask;
 import org.apache.druid.msq.indexing.error.MSQFault;

-public interface RetryTask
+/**
+ * Notifies users of {@link WorkerManager} when a worker fails.
+ */
+public interface WorkerFailureListener
 {
   /**
-   * Retry task when {@link MSQFault} is encountered.
-   *
-   * @param workerTask
-   * @param msqFault
+   * Fires when a worker launched or monitored by {@link WorkerManager} fails.
    */
-  void retry(MSQWorkerTask workerTask, MSQFault msqFault);
+  void onFailure(MSQWorkerTask workerTask, MSQFault msqFault);
 }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java
new file mode 100644
index 000000000000..ebce4821d591
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.msq.indexing.WorkerCount;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Used by {@link ControllerImpl} to discover and manage workers.
+ *
+ * Worker managers capable of retrying workers should extend {@link RetryCapableWorkerManager}, an extension of
+ * this interface.
+ */
+public interface WorkerManager
+{
+  int UNKNOWN_WORKER_NUMBER = -1;
+
+  /**
+   * Starts this manager.
+   *
+   * Returns a future that resolves successfully when all workers end successfully or are canceled. The returned
+   * future resolves to an exception if one of the workers fails without being explicitly canceled, or if something
+   * else goes wrong.
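+   *
+   * A typical lifecycle, sketched using only the methods declared on this interface (error handling elided;
+   * {@code requiredWorkers} and {@code workerNumbers} are hypothetical values):
+   * <pre>{@code
+   * ListenableFuture<?> doneFuture = workerManager.start();
+   * workerManager.launchWorkersIfNeeded(requiredWorkers);
+   * workerManager.waitForWorkers(workerNumbers);
+   * // ... issue work orders via a WorkerClient ...
+   * workerManager.stop(false);
+   * }</pre>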
+   */
+  ListenableFuture<?> start();
+
+  /**
+   * Launches additional workers, if needed, to bring the number of running workers up to {@code workerCount}.
+   * Blocks until the requested workers are launched. If enough workers are already running, this method does
+   * nothing.
+   */
+  void launchWorkersIfNeeded(int workerCount) throws InterruptedException;
+
+  /**
+   * Blocks until workers with the provided worker numbers (indexes into {@link #getWorkerIds()}) are ready to be
+   * contacted for work.
+   */
+  void waitForWorkers(Set<Integer> workerNumbers) throws InterruptedException;
+
+  /**
+   * IDs of currently-active workers.
+   */
+  List<String> getWorkerIds();
+
+  /**
+   * Number of currently-active and currently-pending workers.
+   */
+  WorkerCount getWorkerCount();
+
+  /**
+   * Worker number of the worker with the provided ID, or {@link #UNKNOWN_WORKER_NUMBER} if none exists.
+   */
+  int getWorkerNumber(String workerId);
+
+  /**
+   * Whether an active worker exists with the provided ID.
+   */
+  boolean isWorkerActive(String workerId);
+
+  /**
+   * Map of worker number to list of all workers currently running with that number. More than one worker per number
+   * only occurs when fault tolerance is enabled.
+   */
+  Map<Integer, List<WorkerStats>> getWorkerStats();
+
+  /**
+   * Blocks until all workers exit. Returns quietly, no matter whether the future from {@link #start()} resolved
+   * to an exception or not.
+   *
+   * @param interrupt whether to interrupt currently-running work
+   */
+  void stop(boolean interrupt);
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java
index a09d0508485d..b36b1b4155a8 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java
@@ -473,6 +473,16 @@ static int computeMaxWorkers(
     );
   }

+  /**
+   * Computes the amount of memory needed to read a single partition from a given number of workers.
+   */
+  static long memoryNeededForInputChannels(final int numInputWorkers)
+  {
+    // Workers that read sorted inputs must open all channels at once to do an N-way merge. Calculate memory needs.
+    // Requirement: one input frame per worker, one buffered output frame.
+    return (long) STANDARD_FRAME_SIZE * (numInputWorkers + 1);
+  }
+
   /**
    * Maximum number of workers that may exist in the current JVM.
    */
@@ -563,13 +573,6 @@ private static long estimateUsableMemory(final int numWorkersInJvm, final long e
     return estimatedTotalBundleMemory + (estimateStatOverHeadPerWorker * numWorkersInJvm);
   }

-  private static long memoryNeededForInputChannels(final int numInputWorkers)
-  {
-    // Workers that read sorted inputs must open all channels at once to do an N-way merge. Calculate memory needs.
-    // Requirement: one input frame per worker, one buffered output frame.
-    return (long) STANDARD_FRAME_SIZE * (numInputWorkers + 1);
-  }
-
   private static long memoryNeededForHashPartitioning(final int numOutputPartitions)
   {
     // One standard frame for each processor output.
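The formulas above are easy to sanity-check by hand. A minimal sketch with made-up numbers (the ~1 MB frame size, 75% usable fraction, heap size, and lookup footprint below are illustrative assumptions, not Druid defaults):

    // A quick sanity check of the input-channel formula and the MemoryIntrospectorImpl
    // arithmetic earlier in this patch. Illustrative numbers only.
    public class MemoryMathSketch
    {
      public static void main(String[] args)
      {
        // memoryNeededForInputChannels: one input frame per upstream worker, plus one buffered output frame.
        final long standardFrameSize = 1_000_000;                    // assumed ~1 MB frame size
        final long forInputChannels = standardFrameSize * (10 + 1);  // 10 input workers -> 11,000,000 bytes

        // usableMemoryInJvm and its inverse, computeJvmMemoryRequiredForUsableMemory.
        final double usableFraction = 0.75;
        final long totalHeap = 8L << 30;        // 8 GiB heap
        final long lookupFootprint = 1L << 30;  // 1 GiB of lookups
        final long usable = (long) ((totalHeap - lookupFootprint) * usableFraction);            // 5.25 GiB
        final long requiredHeap = (long) Math.ceil(usable / usableFraction + lookupFootprint);  // back to 8 GiB

        System.out.printf("inputChannels=%d usable=%d requiredHeap=%d%n", forInputChannels, usable, requiredHeap);
      }
    }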
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java index 271ce8ff0709..73f151fcdaa9 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java @@ -30,7 +30,6 @@ import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.function.TriConsumer; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; import org.apache.druid.msq.indexing.error.MSQFault; import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault; import org.apache.druid.msq.kernel.StageId; @@ -58,23 +57,23 @@ public class WorkerSketchFetcher implements AutoCloseable private static final int DEFAULT_THREAD_COUNT = 4; private final WorkerClient workerClient; - private final MSQWorkerTaskLauncher workerTaskLauncher; + private final WorkerManager workerManager; private final boolean retryEnabled; - private AtomicReference isError = new AtomicReference<>(); + private final AtomicReference isError = new AtomicReference<>(); final ExecutorService executorService; public WorkerSketchFetcher( WorkerClient workerClient, - MSQWorkerTaskLauncher workerTaskLauncher, + WorkerManager workerManager, boolean retryEnabled ) { this.workerClient = workerClient; this.executorService = Execs.multiThreaded(DEFAULT_THREAD_COUNT, "SketchFetcherThreadPool-%d"); - this.workerTaskLauncher = workerTaskLauncher; + this.workerManager = workerManager; this.retryEnabled = retryEnabled; } @@ -93,21 +92,14 @@ public void inMemoryFullSketchMerging( for (String taskId : taskIds) { try { - int workerNumber = MSQTasks.workerFromTaskId(taskId); + int workerNumber = workerManager.getWorkerNumber(taskId); executorService.submit(() -> { fetchStatsFromWorker( kernelActions, - () -> workerClient.fetchClusterByStatisticsSnapshot( - taskId, - stageId.getQueryId(), - stageId.getStageNumber() - ), + () -> workerClient.fetchClusterByStatisticsSnapshot(taskId, stageId), taskId, - (kernel, snapshot) -> kernel.mergeClusterByStatisticsCollectorForAllTimeChunks( - stageId, - workerNumber, - snapshot - ), + (kernel, snapshot) -> + kernel.mergeClusterByStatisticsCollectorForAllTimeChunks(stageId, workerNumber, snapshot), retryOperation ); }); @@ -135,9 +127,14 @@ private void fetchStatsFromWorker( executorService.shutdownNow(); return; } - int worker = MSQTasks.workerFromTaskId(taskId); + int worker = workerManager.getWorkerNumber(taskId); + if (worker == WorkerManager.UNKNOWN_WORKER_NUMBER) { + log.info("Task[%s] is no longer the latest task for worker[%d]. Skipping fetch.", taskId, worker); + return; + } + try { - workerTaskLauncher.waitUntilWorkersReady(ImmutableSet.of(worker)); + workerManager.waitForWorkers(ImmutableSet.of(worker)); } catch (InterruptedException interruptedException) { isError.compareAndSet(null, interruptedException); @@ -146,12 +143,8 @@ private void fetchStatsFromWorker( } // if task is not the latest task. It must have retried. - if (!workerTaskLauncher.isTaskLatest(taskId)) { - log.info( - "Task[%s] is no longer the latest task for worker[%d], hence ignoring fetching stats from this worker", - taskId, - worker - ); + if (!workerManager.isWorkerActive(taskId)) { + log.info("Task[%s] is no longer the latest task for worker[%d]. 
Skipping fetch.", taskId, worker); return; } @@ -250,7 +243,7 @@ public void sequentialTimeChunkMerging( completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().forEach((timeChunk, wks) -> { for (String taskId : tasks) { - int workerNumber = MSQTasks.workerFromTaskId(taskId); + int workerNumber = workerManager.getWorkerNumber(taskId); if (wks.contains(workerNumber)) { noBoundaries.remove(taskId); executorService.submit(() -> { @@ -258,8 +251,7 @@ public void sequentialTimeChunkMerging( kernelActions, () -> workerClient.fetchClusterByStatisticsSnapshotForTimeChunk( taskId, - stageId.getQueryId(), - stageId.getStageNumber(), + new StageId(stageId.getQueryId(), stageId.getStageNumber()), timeChunk ), taskId, @@ -281,7 +273,7 @@ public void sequentialTimeChunkMerging( for (String taskId : noBoundaries) { kernelActions.accept( kernel -> { - final int workerNumber = MSQTasks.workerFromTaskId(taskId); + final int workerNumber = workerManager.getWorkerNumber(taskId); kernel.mergeClusterByStatisticsCollectorForAllTimeChunks( stageId, workerNumber, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStats.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStats.java new file mode 100644 index 000000000000..831ea645e40a --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStats.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.exec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.druid.indexer.TaskState; + +import java.util.Objects; + +public class WorkerStats +{ + private final String workerId; + private final TaskState state; + private final long durationMs; + private final long pendingMs; + + @JsonCreator + public WorkerStats( + @JsonProperty("workerId") String workerId, + @JsonProperty("state") TaskState state, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("pendingMs") long pendingMs + ) + { + this.workerId = workerId; + this.state = state; + this.durationMs = durationMs; + this.pendingMs = pendingMs; + } + + @JsonProperty + public String getWorkerId() + { + return workerId; + } + + @JsonProperty + public TaskState getState() + { + return state; + } + + @JsonProperty("durationMs") + public long getDuration() + { + return durationMs; + } + + @JsonProperty("pendingMs") + public long getPendingTimeInMs() + { + return pendingMs; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WorkerStats that = (WorkerStats) o; + return durationMs == that.durationMs + && pendingMs == that.pendingMs + && Objects.equals(workerId, that.workerId) + && state == that.state; + } + + @Override + public int hashCode() + { + return Objects.hash(workerId, state, durationMs, pendingMs); + } + + @Override + public String toString() + { + return "WorkerStats{" + + "workerId='" + workerId + '\'' + + ", state=" + state + + ", durationMs=" + durationMs + + ", pendingMs=" + pendingMs + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/IndexerMemoryManagementModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/IndexerMemoryManagementModule.java new file mode 100644 index 000000000000..92f16a631d9f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/IndexerMemoryManagementModule.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.guice; + +import com.google.inject.Binder; +import com.google.inject.Provides; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.indexing.worker.config.WorkerConfig; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.MemoryIntrospectorImpl; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; +import org.apache.druid.utils.JvmUtils; + +/** + * Provides {@link MemoryIntrospector} for multi-task-per-JVM model. + * + * @see PeonMemoryManagementModule for single-task-per-JVM model used on {@link org.apache.druid.cli.CliPeon} + */ +@LoadScope(roles = NodeRole.INDEXER_JSON_NAME) +public class IndexerMemoryManagementModule implements DruidModule +{ + /** + * Allocate up to 75% of memory for MSQ-related stuff (if all running tasks are MSQ tasks). + */ + private static final double USABLE_MEMORY_FRACTION = 0.75; + + @Override + public void configure(Binder binder) + { + // Nothing to do. + } + + @Provides + @LazySingleton + public Bouncer makeProcessorBouncer(final DruidProcessingConfig processingConfig) + { + return new Bouncer(processingConfig.getNumThreads()); + } + + @Provides + @LazySingleton + public MemoryIntrospector createMemoryIntrospector( + final LookupExtractorFactoryContainerProvider lookupProvider, + final DruidProcessingConfig processingConfig, + final WorkerConfig workerConfig + ) + { + return new MemoryIntrospectorImpl( + lookupProvider, + JvmUtils.getRuntimeInfo().getMaxHeapSizeBytes(), + USABLE_MEMORY_FRACTION, + workerConfig.getCapacity(), + processingConfig.getNumThreads() + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java index 125a66331e60..f4d24cfc5c4c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java @@ -23,11 +23,6 @@ import com.fasterxml.jackson.databind.module.SimpleModule; import com.google.common.collect.ImmutableList; import com.google.inject.Binder; -import com.google.inject.Provides; -import org.apache.druid.discovery.NodeRole; -import org.apache.druid.frame.processor.Bouncer; -import org.apache.druid.guice.LazySingleton; -import org.apache.druid.guice.annotations.Self; import org.apache.druid.initialization.DruidModule; import org.apache.druid.msq.counters.ChannelCounters; import org.apache.druid.msq.counters.CounterSnapshotsSerializer; @@ -94,11 +89,9 @@ import org.apache.druid.msq.querykit.results.QueryResultFrameProcessorFactory; import org.apache.druid.msq.querykit.scan.ScanQueryFrameProcessorFactory; import org.apache.druid.msq.util.PassthroughAggregatorFactory; -import org.apache.druid.query.DruidProcessingConfig; import java.util.Collections; import java.util.List; -import java.util.Set; /** * Module that adds {@link MSQControllerTask}, {@link MSQWorkerTask}, and dependencies. 
@@ -206,17 +199,4 @@ public List getJacksonModules()
   public void configure(Binder binder)
   {
   }
-
-  @Provides
-  @LazySingleton
-  public Bouncer makeBouncer(final DruidProcessingConfig processingConfig, @Self Set<NodeRole> nodeRoles)
-  {
-    if (nodeRoles.contains(NodeRole.PEON) && !nodeRoles.contains(NodeRole.INDEXER)) {
-      // CliPeon -> use only one thread regardless of configured # of processing threads. This matches the expected
-      // resource usage pattern for CliPeon-based tasks (one task / one working thread per JVM).
-      return new Bouncer(1);
-    } else {
-      return new Bouncer(processingConfig.getNumThreads());
-    }
-  }
 }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQSqlModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQSqlModule.java
index 8e381e50bd01..ea6eb364cece 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQSqlModule.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQSqlModule.java
@@ -25,7 +25,6 @@
 import org.apache.druid.discovery.NodeRole;
 import org.apache.druid.guice.LazySingleton;
 import org.apache.druid.guice.annotations.LoadScope;
-import org.apache.druid.guice.annotations.MSQ;
 import org.apache.druid.initialization.DruidModule;
 import org.apache.druid.metadata.input.InputSourceModule;
 import org.apache.druid.msq.sql.MSQTaskSqlEngine;
@@ -62,11 +61,11 @@ public void configure(Binder binder)
   }

   @Provides
-  @MSQ
+  @MultiStageQuery
   @LazySingleton
   public SqlStatementFactory makeMSQSqlStatementFactory(
       final MSQTaskSqlEngine engine,
-      SqlToolbox toolbox
+      final SqlToolbox toolbox
   )
   {
     return new SqlStatementFactory(toolbox.withEngine(engine));
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MultiStageQuery.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MultiStageQuery.java
index 986b1cfb2545..ba017a2c7863 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MultiStageQuery.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MultiStageQuery.java
@@ -26,6 +26,9 @@
 import java.lang.annotation.RetentionPolicy;
 import java.lang.annotation.Target;

+/**
+ * Binding annotation for implementations of interfaces that are MSQ (MultiStageQuery) focused.
+ */
 @Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.METHOD})
 @Retention(RetentionPolicy.RUNTIME)
 @BindingAnnotation
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/PeonMemoryManagementModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/PeonMemoryManagementModule.java
new file mode 100644
index 000000000000..9e814c082781
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/PeonMemoryManagementModule.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.guice; + +import com.google.inject.Binder; +import com.google.inject.Provides; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.MemoryIntrospectorImpl; +import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; +import org.apache.druid.utils.JvmUtils; + +/** + * Provides {@link MemoryIntrospector} for single-task-per-JVM model. + * + * @see IndexerMemoryManagementModule for multi-task-per-JVM model used on {@link org.apache.druid.cli.CliIndexer} + */ +@LoadScope(roles = NodeRole.PEON_JSON_NAME) +public class PeonMemoryManagementModule implements DruidModule +{ + /** + * Peons have a single worker per JVM. + */ + private static final int NUM_WORKERS_IN_JVM = 1; + + /** + * Peons may have more than one processing thread, but we currently only use one of them. + */ + private static final int NUM_PROCESSING_THREADS = 1; + + /** + * Allocate 75% of memory for MSQ-related stuff. + */ + private static final double USABLE_MEMORY_FRACTION = 0.75; + + @Override + public void configure(Binder binder) + { + // Nothing to do. + } + + @Provides + @LazySingleton + public Bouncer makeProcessorBouncer() + { + return new Bouncer(NUM_PROCESSING_THREADS); + } + + @Provides + @LazySingleton + public MemoryIntrospector createMemoryIntrospector( + final LookupExtractorFactoryContainerProvider lookupProvider, + final Bouncer bouncer + ) + { + return new MemoryIntrospectorImpl( + lookupProvider, + JvmUtils.getRuntimeInfo().getMaxHeapSizeBytes(), + USABLE_MEMORY_FRACTION, + NUM_WORKERS_IN_JVM, + bouncer.getMaxCount() + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/SqlTaskModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/SqlTaskModule.java index 52531294f341..d09f8613fa7e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/SqlTaskModule.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/SqlTaskModule.java @@ -19,7 +19,6 @@ package org.apache.druid.msq.guice; -import com.fasterxml.jackson.databind.Module; import com.google.inject.Binder; import org.apache.druid.discovery.NodeRole; import org.apache.druid.guice.Jerseys; @@ -29,9 +28,6 @@ import org.apache.druid.msq.sql.resources.SqlStatementResource; import org.apache.druid.msq.sql.resources.SqlTaskResource; -import java.util.Collections; -import java.util.List; - /** * Module for adding the {@link SqlTaskResource} endpoint to the Broker. 
*/ @@ -47,10 +43,4 @@ public void configure(Binder binder) LifecycleModule.register(binder, SqlStatementResource.class); Jerseys.addResource(binder, SqlStatementResource.class); } - - @Override - public List getJacksonModules() - { - return Collections.emptyList(); - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java index aeee05e75067..3ff71c3e1b77 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java @@ -20,56 +20,110 @@ package org.apache.druid.msq.indexing; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; import com.google.inject.Injector; import com.google.inject.Key; -import org.apache.druid.client.coordinator.CoordinatorClient; import org.apache.druid.guice.annotations.Self; -import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.indexing.common.TaskToolbox; import org.apache.druid.indexing.common.actions.TaskActionClient; +import org.apache.druid.indexing.common.task.IndexTaskUtils; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.io.Closer; -import org.apache.druid.java.util.emitter.service.ServiceEmitter; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.java.util.emitter.service.ServiceMetricEvent; import org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.ControllerMemoryParameters; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.SegmentSource; import org.apache.druid.msq.exec.WorkerClient; -import org.apache.druid.msq.exec.WorkerManagerClient; +import org.apache.druid.msq.exec.WorkerFailureListener; +import org.apache.druid.msq.exec.WorkerManager; import org.apache.druid.msq.indexing.client.ControllerChatHandler; import org.apache.druid.msq.indexing.client.IndexerWorkerClient; -import org.apache.druid.msq.indexing.client.IndexerWorkerManagerClient; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.MSQWarnings; +import org.apache.druid.msq.indexing.error.UnknownFault; +import org.apache.druid.msq.input.InputSpecSlicer; +import org.apache.druid.msq.input.table.TableInputSpecSlicer; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; +import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.DruidMetrics; +import org.apache.druid.query.QueryContext; import org.apache.druid.rpc.ServiceClientFactory; import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.segment.realtime.firehose.ChatHandler; import org.apache.druid.server.DruidNode; +import java.util.Map; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + /** * Implementation for {@link ControllerContext} required to run multi-stage queries as indexing tasks. 
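 *
 * Interplay of fault tolerance and durable shuffle storage, summarizing the logic in
 * {@code makeQueryKernelConfig} below (context keys from {@code MultiStageQueryContext}):
 * <pre>
 * faultTolerance   durableShuffleStorage   result
 * true             unset                   durable storage enabled automatically
 * true             true                    durable storage enabled
 * true             false                   rejected with an error
 * false            any                     follows durableShuffleStorage
 * </pre>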
*/ public class IndexerControllerContext implements ControllerContext { + private static final Logger log = new Logger(IndexerControllerContext.class); + + private final MSQControllerTask task; private final TaskToolbox toolbox; private final Injector injector; private final ServiceClientFactory clientFactory; private final OverlordClient overlordClient; - private final WorkerManagerClient workerManager; + private final ServiceMetricEvent.Builder metricBuilder; public IndexerControllerContext( + final MSQControllerTask task, final TaskToolbox toolbox, final Injector injector, final ServiceClientFactory clientFactory, final OverlordClient overlordClient ) { + this.task = task; this.toolbox = toolbox; this.injector = injector; this.clientFactory = clientFactory; this.overlordClient = overlordClient; - this.workerManager = new IndexerWorkerManagerClient(overlordClient); + this.metricBuilder = new ServiceMetricEvent.Builder(); + IndexTaskUtils.setTaskDimensions(metricBuilder, task); + } + + @Override + public ControllerQueryKernelConfig queryKernelConfig( + final MSQSpec querySpec, + final QueryDefinition queryDef + ) + { + final MemoryIntrospector memoryIntrospector = injector.getInstance(MemoryIntrospector.class); + final ControllerMemoryParameters memoryParameters = + ControllerMemoryParameters.createProductionInstance( + memoryIntrospector, + queryDef.getFinalStageDefinition().getMaxWorkerCount() + ); + + final ControllerQueryKernelConfig config = makeQueryKernelConfig(querySpec, memoryParameters); + + log.debug( + "Query[%s] using %s[%s], %s[%s], %s[%s].", + queryDef.getQueryId(), + MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, + config.isDurableStorage(), + MultiStageQueryContext.CTX_FAULT_TOLERANCE, + config.isFaultTolerant(), + MultiStageQueryContext.CTX_MAX_CONCURRENT_STAGES, + config.getMaxConcurrentStages() + ); + + return config; } @Override - public ServiceEmitter emitter() + public void emitMetric(String metric, Number value) { - return toolbox.getEmitter(); + toolbox.getEmitter().emit(metricBuilder.setMetric(metric, value)); } @Override @@ -91,9 +145,15 @@ public DruidNode selfNode() } @Override - public CoordinatorClient coordinatorClient() + public InputSpecSlicer newTableInputSpecSlicer() { - return toolbox.getCoordinatorClient(); + final SegmentSource includeSegmentSource = + MultiStageQueryContext.getSegmentSources(task.getQuerySpec().getQuery().context()); + return new TableInputSpecSlicer( + toolbox.getCoordinatorClient(), + toolbox.getTaskActionClient(), + includeSegmentSource + ); } @Override @@ -103,29 +163,126 @@ public TaskActionClient taskActionClient() } @Override - public WorkerClient taskClientFor(Controller controller) + public WorkerClient newWorkerClient() { - // Ignore controller parameter. 
return new IndexerWorkerClient(clientFactory, overlordClient, jsonMapper()); } @Override public void registerController(Controller controller, final Closer closer) { - ChatHandler chatHandler = new ControllerChatHandler(toolbox, controller); - toolbox.getChatHandlerProvider().register(controller.id(), chatHandler, false); - closer.register(() -> toolbox.getChatHandlerProvider().unregister(controller.id())); + ChatHandler chatHandler = new ControllerChatHandler( + controller, + task.getDataSource(), + toolbox.getAuthorizerMapper() + ); + toolbox.getChatHandlerProvider().register(controller.queryId(), chatHandler, false); + closer.register(() -> toolbox.getChatHandlerProvider().unregister(controller.queryId())); } @Override - public WorkerManagerClient workerManager() + public WorkerManager newWorkerManager( + final String queryId, + final MSQSpec querySpec, + final ControllerQueryKernelConfig queryKernelConfig, + final WorkerFailureListener workerFailureListener + ) { - return workerManager; + return new MSQWorkerTaskLauncher( + queryId, + task.getDataSource(), + overlordClient, + workerFailureListener, + makeTaskContext(querySpec, queryKernelConfig, task.getContext()), + // 10 minutes +- 2 minutes jitter + TimeUnit.SECONDS.toMillis(600 + ThreadLocalRandom.current().nextInt(-4, 5) * 30L) + ); } - @Override - public void writeReports(String controllerTaskId, TaskReport.ReportMap reports) + /** + * Helper method for {@link #queryKernelConfig(MSQSpec, QueryDefinition)}. Also used in tests. + */ + public static ControllerQueryKernelConfig makeQueryKernelConfig( + final MSQSpec querySpec, + final ControllerMemoryParameters memoryParameters + ) { - toolbox.getTaskReportFileWriter().write(controllerTaskId, reports); + final QueryContext queryContext = querySpec.getQuery().context(); + final int maxConcurrentStages = MultiStageQueryContext.getMaxConcurrentStages(queryContext); + final boolean isFaultToleranceEnabled = MultiStageQueryContext.isFaultToleranceEnabled(queryContext); + final boolean isDurableStorageEnabled; + + if (isFaultToleranceEnabled) { + if (!queryContext.containsKey(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE)) { + // if context key not set, enable durableStorage automatically. + isDurableStorageEnabled = true; + } else { + // if context key is set, and durableStorage is turned on. + if (MultiStageQueryContext.isDurableStorageEnabled(queryContext)) { + isDurableStorageEnabled = true; + } else { + throw new MSQException( + UnknownFault.forMessage( + StringUtils.format( + "Context param[%s] cannot be explicitly set to false when context param[%s] is" + + " set to true. Either remove the context param[%s] or explicitly set it to true.", + MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, + MultiStageQueryContext.CTX_FAULT_TOLERANCE, + MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE + ) + ) + ); + } + } + } else { + isDurableStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(queryContext); + } + + return ControllerQueryKernelConfig + .builder() + .pipeline(maxConcurrentStages > 1) + .durableStorage(isDurableStorageEnabled) + .faultTolerance(isFaultToleranceEnabled) + .destination(querySpec.getDestination()) + .maxConcurrentStages(maxConcurrentStages) + .maxRetainedPartitionSketchBytes(memoryParameters.getPartitionStatisticsMaxRetainedBytes()) + .build(); + } + + /** + * Helper method for {@link #newWorkerManager}, split out to be used in tests. 
+   *
+   * @param querySpec MSQ query spec; used for its query context and destination
+   */
+  public static Map<String, Object> makeTaskContext(
+      final MSQSpec querySpec,
+      final ControllerQueryKernelConfig queryKernelConfig,
+      final Map<String, Object> controllerTaskContext
+  )
+  {
+    final ImmutableMap.Builder<String, Object> taskContextOverridesBuilder = ImmutableMap.builder();
+    final long maxParseExceptions = MultiStageQueryContext.getMaxParseExceptions(querySpec.getQuery().context());
+
+    taskContextOverridesBuilder
+        .put(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, queryKernelConfig.isDurableStorage())
+        .put(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, maxParseExceptions)
+        .put(MultiStageQueryContext.CTX_IS_REINDEX, MSQControllerTask.isReplaceInputDataSourceTask(querySpec))
+        .put(MultiStageQueryContext.CTX_MAX_CONCURRENT_STAGES, queryKernelConfig.getMaxConcurrentStages());
+
+    if (querySpec.getDestination().toSelectDestination() != null) {
+      taskContextOverridesBuilder.put(
+          MultiStageQueryContext.CTX_SELECT_DESTINATION,
+          querySpec.getDestination().toSelectDestination().getName()
+      );
+    }
+
+    // propagate the controller's tags to the worker task for enhanced metrics reporting
+    @SuppressWarnings("unchecked")
+    Map<String, Object> tags = (Map<String, Object>) controllerTaskContext.get(DruidMetrics.TAGS);
+    if (tags != null) {
+      taskContextOverridesBuilder.put(DruidMetrics.TAGS, tags);
+    }
+
+    return taskContextOverridesBuilder.build();
+  }
 }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java
new file mode 100644
index 000000000000..30bc75282fa4
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.indexing;
+
+import org.apache.druid.cli.CliIndexer;
+import org.apache.druid.cli.CliPeon;
+import org.apache.druid.msq.rpc.ResourcePermissionMapper;
+import org.apache.druid.server.security.Action;
+import org.apache.druid.server.security.Resource;
+import org.apache.druid.server.security.ResourceAction;
+import org.apache.druid.server.security.ResourceType;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Production implementation of {@link ResourcePermissionMapper} for tasks: {@link CliIndexer} and {@link CliPeon}.
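+ *
+ * For example (illustrative), for a task writing to datasource {@code wiki}, {@link #getAdminPermissions()}
+ * returns a single {@code ResourceAction} requiring {@code WRITE} on the {@code wiki} datasource resource.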
+ */
+public class IndexerResourcePermissionMapper implements ResourcePermissionMapper
+{
+  private final String dataSource;
+
+  public IndexerResourcePermissionMapper(String dataSource)
+  {
+    this.dataSource = dataSource;
+  }
+
+  @Override
+  public List<ResourceAction> getAdminPermissions()
+  {
+    return Collections.singletonList(
+        new ResourceAction(
+            new Resource(dataSource, ResourceType.DATASOURCE),
+            Action.WRITE
+        )
+    );
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
index 9cc9e4dae745..e645e0e62cd8 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
@@ -47,6 +47,7 @@
 import org.apache.druid.msq.exec.ControllerContext;
 import org.apache.druid.msq.exec.ControllerImpl;
 import org.apache.druid.msq.exec.MSQTasks;
+import org.apache.druid.msq.exec.ResultsContext;
 import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination;
 import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination;
 import org.apache.druid.msq.indexing.destination.ExportMSQDestination;
@@ -246,20 +247,37 @@ public TaskStatus runTask(final TaskToolbox toolbox) throws Exception
     final OverlordClient overlordClient = injector.getInstance(OverlordClient.class)
                                                   .withRetryPolicy(StandardRetryPolicy.unlimited());
     final ControllerContext context = new IndexerControllerContext(
+        this,
         toolbox,
         injector,
         clientFactory,
         overlordClient
     );
-    controller = new ControllerImpl(this, context);
-    return controller.run();
+
+    controller = new ControllerImpl(
+        this.getId(),
+        querySpec,
+        new ResultsContext(getSqlTypeNames(), getSqlResultsContext()),
+        context
+    );
+
+    final TaskReportQueryListener queryListener = new TaskReportQueryListener(
+        querySpec.getDestination(),
+        () -> toolbox.getTaskReportFileWriter().openReportOutputStream(getId()),
+        toolbox.getJsonMapper(),
+        getId(),
+        getContext()
+    );
+
+    controller.run(queryListener);
+    return queryListener.getStatusReport().toTaskStatus(getId());
   }

   @Override
   public void stopGracefully(final TaskConfig taskConfig)
   {
     if (controller != null) {
-      controller.stopGracefully();
+      controller.stop();
     }
   }

@@ -300,14 +318,15 @@ public static boolean isExport(final MSQSpec querySpec)
   * Returns true if the task reads from the same table as the destination. In this case, we would prefer to fail
   * instead of reading any unused segments to ensure that old data is not read.
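   * For example (illustrative SQL), an ingest query such as
   * {@code REPLACE INTO "wikipedia" ... SELECT ... FROM "wikipedia"} reads from its own destination datasource,
   * so this method returns true for it.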
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
index 9cc9e4dae745..e645e0e62cd8 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
@@ -47,6 +47,7 @@
 import org.apache.druid.msq.exec.ControllerContext;
 import org.apache.druid.msq.exec.ControllerImpl;
 import org.apache.druid.msq.exec.MSQTasks;
+import org.apache.druid.msq.exec.ResultsContext;
 import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination;
 import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination;
 import org.apache.druid.msq.indexing.destination.ExportMSQDestination;
@@ -246,20 +247,37 @@ public TaskStatus runTask(final TaskToolbox toolbox) throws Exception
 final OverlordClient overlordClient = injector.getInstance(OverlordClient.class)
 .withRetryPolicy(StandardRetryPolicy.unlimited());
 final ControllerContext context = new IndexerControllerContext(
+ this,
 toolbox,
 injector,
 clientFactory,
 overlordClient
 );
- controller = new ControllerImpl(this, context);
- return controller.run();
+
+ controller = new ControllerImpl(
+ this.getId(),
+ querySpec,
+ new ResultsContext(getSqlTypeNames(), getSqlResultsContext()),
+ context
+ );
+
+ final TaskReportQueryListener queryListener = new TaskReportQueryListener(
+ querySpec.getDestination(),
+ () -> toolbox.getTaskReportFileWriter().openReportOutputStream(getId()),
+ toolbox.getJsonMapper(),
+ getId(),
+ getContext()
+ );
+
+ controller.run(queryListener);
+ return queryListener.getStatusReport().toTaskStatus(getId());
 }

 @Override
 public void stopGracefully(final TaskConfig taskConfig)
 {
 if (controller != null) {
- controller.stopGracefully();
+ controller.stop();
 }
 }
@@ -300,14 +318,15 @@ public static boolean isExport(final MSQSpec querySpec)
 * Returns true if the task reads from the same table as the destination. In this case, we would prefer to fail
 * instead of reading any unused segments, to ensure that old data is not read.
 */
-public static boolean isReplaceInputDataSourceTask(MSQControllerTask task)
+public static boolean isReplaceInputDataSourceTask(MSQSpec querySpec)
 {
- return task.getQuerySpec()
- .getQuery()
- .getDataSource()
- .getTableNames()
- .stream()
- .anyMatch(datasouce -> task.getDataSource().equals(datasouce));
+ if (isIngestion(querySpec)) {
+ final String targetDataSource = ((DataSourceMSQDestination) querySpec.getDestination()).getDataSource();
+ final Set<String> sourceTableNames = querySpec.getQuery().getDataSource().getTableNames();
+ return sourceTableNames.contains(targetDataSource);
+ } else {
+ return false;
+ }
 }

 public static boolean writeResultsToDurableStorage(final MSQSpec querySpec)
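Distilled to a standalone sketch (helper name and example datasources are invented for illustration), the reworked check above reduces to a set-membership test:

```java
// Ingest queries only: a query counts as a "reindex" when its destination
// datasource also appears among the source tables it reads.
static boolean isReindex(String targetDataSource, Set<String> sourceTableNames)
{
  return sourceTableNames.contains(targetDataSource);
}

// isReindex("wikipedia", Set.of("wikipedia"))      -> true  (replace-input case)
// isReindex("wikipedia_new", Set.of("wikipedia"))  -> false
```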
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java
index 55ff6a3876d8..ed32b81f44ef 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java
@@ -19,13 +19,15 @@
 package org.apache.druid.msq.indexing;

-import com.fasterxml.jackson.annotation.JsonProperty;
 import com.google.common.collect.ImmutableList;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.SettableFuture;
 import com.google.errorprone.annotations.concurrent.GuardedBy;
+import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
 import it.unimi.dsi.fastutil.ints.IntSet;
+import org.apache.druid.client.indexing.TaskStatusResponse;
 import org.apache.druid.common.guava.FutureUtils;
 import org.apache.druid.indexer.TaskLocation;
 import org.apache.druid.indexer.TaskState;
@@ -34,17 +36,18 @@
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.java.util.common.concurrent.Execs;
 import org.apache.druid.java.util.common.logger.Logger;
-import org.apache.druid.msq.exec.ControllerContext;
-import org.apache.druid.msq.exec.ControllerImpl;
 import org.apache.druid.msq.exec.Limits;
 import org.apache.druid.msq.exec.MSQTasks;
-import org.apache.druid.msq.exec.WorkerManagerClient;
+import org.apache.druid.msq.exec.RetryCapableWorkerManager;
+import org.apache.druid.msq.exec.WorkerFailureListener;
+import org.apache.druid.msq.exec.WorkerStats;
 import org.apache.druid.msq.indexing.error.MSQException;
 import org.apache.druid.msq.indexing.error.TaskStartTimeoutFault;
 import org.apache.druid.msq.indexing.error.TooManyAttemptsForJob;
 import org.apache.druid.msq.indexing.error.TooManyAttemptsForWorker;
 import org.apache.druid.msq.indexing.error.UnknownFault;
 import org.apache.druid.msq.indexing.error.WorkerFailedFault;
+import org.apache.druid.rpc.indexing.OverlordClient;

 import java.time.Duration;
 import java.util.ArrayList;
@@ -56,19 +59,17 @@
 import java.util.Map;
 import java.util.OptionalLong;
 import java.util.Set;
-import java.util.TreeMap;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
-import java.util.concurrent.ConcurrentSkipListMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;

 /**
 * Like {@link org.apache.druid.indexing.common.task.batch.parallel.TaskMonitor}, but different.
 */
-public class MSQWorkerTaskLauncher
+public class MSQWorkerTaskLauncher implements RetryCapableWorkerManager
 {
 private static final Logger log = new Logger(MSQWorkerTaskLauncher.class);
 private static final long HIGH_FREQUENCY_CHECK_MILLIS = 100;
@@ -87,7 +88,7 @@ private enum State
 private final String controllerTaskId;
 private final String dataSource;
- private final ControllerContext context;
+ private final OverlordClient overlordClient;
 private final ExecutorService exec;
 private final long maxTaskStartDelayMillis;
@@ -108,15 +109,19 @@ private enum State
 @GuardedBy("taskIds")
 private final List<String> taskIds = new ArrayList<>();

+ // Task ID -> worker number. Only set for active workers.
+ @GuardedBy("taskIds")
+ private final Map<String, Integer> taskIdToWorkerNumber = new HashMap<>();
+
 // Worker number -> whether the task has fully started up or not.
 @GuardedBy("taskIds")
 private final IntSet fullyStartedTasks = new IntOpenHashSet();

- // Mutable state accessed by mainLoop, ControllerImpl, and jetty (/liveReports) threads.
+ // Mutable state written only by the mainLoop() thread.
 // Tasks are added here once they are submitted for running, but before they are fully started up.
- // taskId -> taskTracker
- private final ConcurrentMap<String, TaskTracker> taskTrackers = new ConcurrentSkipListMap<>(Comparator.comparingInt(
- MSQTasks::workerFromTaskId));
+ // Uses a concurrent map because getWorkerStats() reads this map too, and getWorkerStats() can be called by various
+ // other threads.
+ private final ConcurrentHashMap<String, TaskTracker> taskTrackers = new ConcurrentHashMap<>();

 // Set of tasks which are issued a cancel request by the controller.
 private final Set<String> canceledWorkerTasks = ConcurrentHashMap.newKeySet();
@@ -135,35 +140,31 @@ private enum State
 private final Set failedInactiveWorkers = ConcurrentHashMap.newKeySet();

 private final ConcurrentHashMap<Integer, List<String>> workerToTaskIds = new ConcurrentHashMap<>();
- private final RetryTask retryTask;
+ private final WorkerFailureListener workerFailureListener;
 private final AtomicLong recentFullyStartedWorkerTimeInMillis = new AtomicLong(System.currentTimeMillis());

 public MSQWorkerTaskLauncher(
 final String controllerTaskId,
 final String dataSource,
- final ControllerContext context,
- final RetryTask retryTask,
+ final OverlordClient overlordClient,
+ final WorkerFailureListener workerFailureListener,
 final Map<String, Object> taskContextOverrides,
 final long maxTaskStartDelayMillis
 )
 {
 this.controllerTaskId = controllerTaskId;
 this.dataSource = dataSource;
- this.context = context;
+ this.overlordClient = overlordClient;
 this.taskContextOverrides = taskContextOverrides;
 this.exec = Execs.singleThreaded(
 "multi-stage-query-task-launcher[" + StringUtils.encodeForFormat(controllerTaskId) + "]-%s"
 );
- this.retryTask = retryTask;
+ this.workerFailureListener = workerFailureListener;
 this.maxTaskStartDelayMillis = maxTaskStartDelayMillis;
 }

- /**
- * Launches tasks, blocking until they are all in RUNNING state. Returns a future that resolves successfully when
- * all tasks end successfully or are canceled. The returned future resolves to an exception if one of the tasks fails
- * without being explicitly canceled, or if something else goes wrong.
- */
+ @Override
 public ListenableFuture<?> start()
 {
 if (state.compareAndSet(State.NEW, State.STARTED)) {
@@ -181,10 +182,7 @@ public ListenableFuture<?> start()
 return stopFuture;
 }

- /**
- * Stops all tasks, blocking until they exit. Returns quietly, no matter whether there was an exception
- * associated with the future from {@link #start()} or not.
- */
+ @Override
 public void stop(final boolean interrupt)
 {
 if (state.compareAndSet(State.NEW, State.STOPPED)) {
@@ -221,24 +219,24 @@ public void stop(final boolean interrupt)
 }

 // Block until stopped.
- waitForWorkerShutdown();
+ try {
+ FutureUtils.getUnchecked(stopFuture, false);
+ }
+ catch (Throwable ignored) {
+ // Suppress.
+ }
 }

- /**
- * Get the list of currently-active tasks.
- */
- public List<String> getActiveTasks()
+ @Override
+ public List<String> getWorkerIds()
 {
 synchronized (taskIds) {
 return ImmutableList.copyOf(taskIds);
 }
 }

- /**
- * Launch additional tasks, if needed, to bring the size of {@link #taskIds} up to {@code taskCount}. If enough
- * tasks are already running, this method does nothing.
- */
- public void launchTasksIfNeeded(final int taskCount) throws InterruptedException
+ @Override
+ public void launchWorkersIfNeeded(final int taskCount) throws InterruptedException
 {
 synchronized (taskIds) {
 retryInactiveTasksIfNeeded(taskCount);
@@ -280,21 +278,13 @@ Set<Integer> getWorkersToRelaunch()
 return workersToRelaunch;
 }

- /**
- * Queues worker for relaunch. A noop if the worker is already in the queue.
- *
- * @param workerNumber worker number
- */
+ @Override
 public void submitForRelaunch(int workerNumber)
 {
 workersToRelaunch.add(workerNumber);
 }

- /**
- * Report a worker that failed without active orders. To be retried only if it is required for future stages.
- *
- * @param workerNumber worker number
- */
+ @Override
 public void reportFailedInactiveWorker(int workerNumber)
 {
 synchronized (taskIds) {
@@ -302,16 +292,11 @@ public void reportFailedInactiveWorker(int workerNumber)
 }
 }

- /**
- * Blocks the call until the worker tasks are ready to be contacted for work.
- *
- * @param workerSet
- * @throws InterruptedException
- */
- public void waitUntilWorkersReady(Set<Integer> workerSet) throws InterruptedException
+ @Override
+ public void waitForWorkers(Set<Integer> workerNumbers) throws InterruptedException
 {
 synchronized (taskIds) {
- while (!fullyStartedTasks.containsAll(workerSet)) {
+ while (!fullyStartedTasks.containsAll(workerNumbers)) {
 if (stopFuture.isDone() || stopFuture.isCancelled()) {
 FutureUtils.getUnchecked(stopFuture, false);
 throw new ISE("Stopped");
@@ -321,40 +306,30 @@ public void waitUntilWorkersReady(Set<Integer> workerSet) throws InterruptedException
 }
 }
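For orientation, a hedged sketch of how a controller might drive this class through the worker-manager interface it now implements; the worker count is illustrative and error handling is omitted:

```java
// start() returns a future that resolves when all workers finish or are canceled.
ListenableFuture<?> doneFuture = workerManager.start();

// Ensure workers 0..3 exist, then block until they are fully started.
workerManager.launchWorkersIfNeeded(4);
workerManager.waitForWorkers(ImmutableSet.of(0, 1, 2, 3));

// ... assign work orders to the workers ...

// Graceful shutdown; stop(true) would interrupt in-flight tasks instead.
workerManager.stop(false);
```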
- public void waitForWorkerShutdown()
- {
- try {
- FutureUtils.getUnchecked(stopFuture, false);
- }
- catch (Throwable ignored) {
- // Suppress.
- }
- }
-
- /**
- * Checks if the controller has canceled the input taskId. This method is used in {@link ControllerImpl}
- * to figure out if the worker taskId is canceled by the controller. If yes, the errors from that worker taskId
- * are ignored for the error reports.
- *
- * @return true if task is canceled by the controller, else false
- */
+ @Override
 public boolean isTaskCanceledByController(String taskId)
 {
 return canceledWorkerTasks.contains(taskId);
 }

+ @Override
+ public int getWorkerNumber(String taskId)
+ {
+ return MSQTasks.workerFromTaskId(taskId);
+ }

- public boolean isTaskLatest(String taskId)
+ @Override
+ public boolean isWorkerActive(String taskId)
 {
- int worker = MSQTasks.workerFromTaskId(taskId);
 synchronized (taskIds) {
- return taskId.equals(taskIds.get(worker));
+ return taskIdToWorkerNumber.get(taskId) != null;
 }
 }

+ @Override
 public Map<Integer, List<WorkerStats>> getWorkerStats()
 {
- final Map<Integer, List<WorkerStats>> workerStats = new TreeMap<>();
+ final Int2ObjectMap<List<WorkerStats>> workerStats = new Int2ObjectAVLTreeMap<>();

 for (Map.Entry<String, TaskTracker> taskEntry : taskTrackers.entrySet()) {
 final TaskTracker taskTracker = taskEntry.getValue();
@@ -393,6 +368,7 @@ private void mainLoop()
 cleanFailedTasksWhichAreRelaunched();
 }
 catch (Throwable e) {
+ log.warn(e, "Stopped due to exception in task management loop.");
 state.set(State.STOPPED);
 cancelTasksOnStop.set(true);
 caught = e;
@@ -491,9 +467,10 @@ private void runNewTasks()
 return taskIds;
 });

- context.workerManager().run(task.getId(), task);
+ FutureUtils.getUnchecked(overlordClient.runTask(task.getId(), task), true);

 synchronized (taskIds) {
+ taskIdToWorkerNumber.put(task.getId(), taskIds.size());
 taskIds.add(task.getId());
 taskIds.notifyAll();
 }
@@ -504,7 +481,8 @@ private void runNewTasks()
 * Returns a pair which contains the number of currently running worker tasks and the number of worker tasks that are
 * not yet fully started as left and right respectively.
 */
- public WorkerCount getWorkerTaskCount()
+ @Override
+ public WorkerCount getWorkerCount()
 {
 synchronized (taskIds) {
 if (stopFuture.isDone()) {
@@ -530,8 +508,8 @@ private void updateTaskTrackersAndTaskIds()
 }

 if (!taskStatusesNeeded.isEmpty()) {
- final WorkerManagerClient workerManager = context.workerManager();
- final Map<String, TaskStatus> statuses = workerManager.statuses(taskStatusesNeeded);
+ final Map<String, TaskStatus> statuses =
+ FutureUtils.getUnchecked(overlordClient.taskStatuses(taskStatusesNeeded), true);

 for (Map.Entry<String, TaskStatus> statusEntry : statuses.entrySet()) {
 final String taskId = statusEntry.getKey();
@@ -542,7 +520,13 @@
 if (!status.getStatusCode().isComplete() && tracker.unknownLocation()) {
 // Look up location if not known. Note: this location is not used to actually contact the task. For that,
 // we have SpecificTaskServiceLocator. This location is only used to determine if a task has started up.
- tracker.setLocation(workerManager.location(taskId)); + final TaskStatusResponse taskStatusResponse = + FutureUtils.getUnchecked(overlordClient.taskStatus(taskId), true); + if (taskStatusResponse.getStatus() != null) { + tracker.setLocation(taskStatusResponse.getStatus().getLocation()); + } else { + tracker.setLocation(TaskLocation.unknown()); + } } if (status.getStatusCode() == TaskState.RUNNING && !tracker.unknownLocation()) { @@ -568,10 +552,7 @@ private void checkForErroneousTasks() { final int numTasks = taskTrackers.size(); - Iterator> taskTrackerIterator = taskTrackers.entrySet().iterator(); - - while (taskTrackerIterator.hasNext()) { - final Map.Entry taskEntry = taskTrackerIterator.next(); + for (Map.Entry taskEntry : taskTrackersByWorkerNumber()) { final String taskId = taskEntry.getKey(); final TaskTracker tracker = taskEntry.getValue(); if (tracker.isRetrying()) { @@ -583,7 +564,7 @@ private void checkForErroneousTasks() final String errorMessage = StringUtils.format("Task [%s] status missing", taskId); log.info(errorMessage + ". Trying to relaunch the worker"); tracker.enableRetrying(); - retryTask.retry( + workerFailureListener.onFailure( tracker.msqWorkerTask, UnknownFault.forMessage(errorMessage) ); @@ -591,7 +572,7 @@ private void checkForErroneousTasks() } else if (tracker.didRunTimeOut(maxTaskStartDelayMillis) && !canceledWorkerTasks.contains(taskId)) { removeWorkerFromFullyStartedWorkers(tracker); throw new MSQException(new TaskStartTimeoutFault( - this.getWorkerTaskCount().getPendingWorkerCount(), + this.getWorkerCount().getPendingWorkerCount(), numTasks + 1, maxTaskStartDelayMillis )); @@ -600,7 +581,10 @@ private void checkForErroneousTasks() TaskStatus taskStatus = tracker.statusRef.get(); log.info("Task[%s] failed because %s. 
Trying to relaunch the worker", taskId, taskStatus.getErrorMsg());
 tracker.enableRetrying();
- retryTask.retry(tracker.msqWorkerTask, new WorkerFailedFault(taskId, taskStatus.getErrorMsg()));
+ workerFailureListener.onFailure(
+ tracker.msqWorkerTask,
+ new WorkerFailedFault(taskId, taskStatus.getErrorMsg())
+ );
 }
 }
@@ -658,16 +642,18 @@ private void relaunchTasks()
 taskIds.notifyAll();
 }

- context.workerManager().run(relaunchedTask.getId(), relaunchedTask);
+ FutureUtils.getUnchecked(overlordClient.runTask(relaunchedTask.getId(), relaunchedTask), true);
 taskHistory.add(relaunchedTask.getId());

 synchronized (taskIds) {
 // replace taskId with the retry taskID for the same worker number
+ taskIdToWorkerNumber.remove(taskIds.get(toRelaunch.getWorkerNumber()));
 taskIds.set(toRelaunch.getWorkerNumber(), relaunchedTask.getId());
+ taskIdToWorkerNumber.put(relaunchedTask.getId(), toRelaunch.getWorkerNumber());
 taskIds.notifyAll();
 }

- return taskHistory;
+ return taskHistory;
 });
 iterator.remove();
 }
@@ -697,14 +683,14 @@ private void shutDownTasks()
 {
 cleanFailedTasksWhichAreRelaunched();

- for (final Map.Entry<String, TaskTracker> taskEntry : taskTrackers.entrySet()) {
+ for (final Map.Entry<String, TaskTracker> taskEntry : taskTrackersByWorkerNumber()) {
 final String taskId = taskEntry.getKey();
 final TaskTracker tracker = taskEntry.getValue();

 if ((!canceledWorkerTasks.contains(taskId))
 && (!tracker.isComplete())) {
 canceledWorkerTasks.add(taskId);
- context.workerManager().cancel(taskId);
+ FutureUtils.getUnchecked(overlordClient.cancelTask(taskId), true);
 }
 }
 }
@@ -720,7 +706,7 @@ private void cleanFailedTasksWhichAreRelaunched()
 try {
 if (canceledWorkerTasks.add(taskId)) {
 try {
- context.workerManager().cancel(taskId);
+ FutureUtils.getUnchecked(overlordClient.cancelTask(taskId), true);
 }
 catch (Exception ignore) {
 //ignoring cancellation exception
@@ -730,7 +716,6 @@
 finally {
 tasksToCancel.remove();
 }
-
 }
 }
@@ -749,6 +734,17 @@ private boolean allTasksStarted(final int taskCount)
 return true;
 }

+ /**
+ * Returns entries of {@link #taskTrackers} sorted by worker number.
+ */
+ private List<Map.Entry<String, TaskTracker>> taskTrackersByWorkerNumber()
+ {
+ return taskTrackers.entrySet()
+ .stream()
+ .sorted(Comparator.comparing(entry -> entry.getValue().workerNumber))
+ .collect(Collectors.toList());
+ }
+
 /**
 * Used by the main loop to decide how often to check task status.
*/ @@ -885,51 +881,4 @@ public long taskPendingTimeInMs() } } } - - public static class WorkerStats - { - String workerId; - TaskState state; - long duration; - long pendingTimeInMs; - - /** - * For JSON deserialization only - */ - public WorkerStats() - { - } - - public WorkerStats(String workerId, TaskState state, long duration, long pendingTimeInMs) - { - this.workerId = workerId; - this.state = state; - this.duration = duration; - this.pendingTimeInMs = pendingTimeInMs; - } - - @JsonProperty - public String getWorkerId() - { - return workerId; - } - - @JsonProperty - public TaskState getState() - { - return state; - } - - @JsonProperty("durationMs") - public long getDuration() - { - return duration; - } - - @JsonProperty("pendingMs") - public long getPendingTimeInMs() - { - return pendingTimeInMs; - } - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java new file mode 100644 index 000000000000..4cc4678a58a7 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.indexing; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.indexer.report.TaskContextReport; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.jackson.JacksonUtils; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.exec.QueryListener; +import org.apache.druid.msq.indexing.destination.MSQDestination; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.indexing.report.MSQStatusReport; +import org.apache.druid.msq.indexing.report.MSQTaskReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.io.OutputStream; +import java.util.List; +import java.util.Map; + +/** + * Query listener that writes {@link MSQTaskReport} to an {@link OutputStream}. + * + * This is used so the report can be written one row at a time, as results are being read, as part of the main + * query loop. This allows reports to scale to row counts that cannot be materialized in memory, and allows + * report-writing to be interleaved with query execution when using {@link OutputChannelMode#MEMORY}. 
+ */
+public class TaskReportQueryListener implements QueryListener
+{
+ private static final String FIELD_TYPE = "type";
+ private static final String FIELD_TASK_ID = "taskId";
+ private static final String FIELD_PAYLOAD = "payload";
+ private static final String FIELD_STATUS = "status";
+ private static final String FIELD_STAGES = "stages";
+ private static final String FIELD_COUNTERS = "counters";
+ private static final String FIELD_RESULTS = "results";
+ private static final String FIELD_RESULTS_SIGNATURE = "signature";
+ private static final String FIELD_RESULTS_SQL_TYPE_NAMES = "sqlTypeNames";
+ private static final String FIELD_RESULTS_RESULTS = "results";
+ private static final String FIELD_RESULTS_TRUNCATED = "resultsTruncated";
+
+ private final long rowsInTaskReport;
+ private final OutputStreamSupplier reportSink;
+ private final ObjectMapper jsonMapper;
+ private final SerializerProvider serializers;
+ private final String taskId;
+ private final Map<String, Object> taskContext;
+
+ private JsonGenerator jg;
+ private long numResults;
+ private MSQStatusReport statusReport;
+
+ public TaskReportQueryListener(
+ final MSQDestination destination,
+ final OutputStreamSupplier reportSink,
+ final ObjectMapper jsonMapper,
+ final String taskId,
+ final Map<String, Object> taskContext
+ )
+ {
+ this.rowsInTaskReport = destination.getRowsInTaskReport();
+ this.reportSink = reportSink;
+ this.jsonMapper = jsonMapper;
+ this.serializers = jsonMapper.getSerializerProviderInstance();
+ this.taskId = taskId;
+ this.taskContext = taskContext;
+ }
+
+ @Override
+ public boolean readResults()
+ {
+ return rowsInTaskReport == MSQDestination.UNLIMITED || rowsInTaskReport > 0;
+ }
+
+ @Override
+ public void onResultsStart(List<MSQResultsReport.ColumnAndType> signature, @Nullable List<SqlTypeName> sqlTypeNames)
+ {
+ try {
+ openGenerator();
+
+ jg.writeObjectFieldStart(FIELD_RESULTS);
+ writeObjectField(FIELD_RESULTS_SIGNATURE, signature);
+ if (sqlTypeNames != null) {
+ writeObjectField(FIELD_RESULTS_SQL_TYPE_NAMES, sqlTypeNames);
+ }
+ jg.writeArrayFieldStart(FIELD_RESULTS_RESULTS);
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public boolean onResultRow(Object[] row)
+ {
+ try {
+ JacksonUtils.writeObjectUsingSerializerProvider(jg, serializers, row);
+ numResults++;
+
+ if (rowsInTaskReport == MSQDestination.UNLIMITED || numResults < rowsInTaskReport) {
+ return true;
+ } else {
+ jg.writeEndArray();
+ jg.writeBooleanField(FIELD_RESULTS_TRUNCATED, true);
+ jg.writeEndObject();
+ return false;
+ }
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void onResultsComplete()
+ {
+ try {
+ jg.writeEndArray();
+ jg.writeBooleanField(FIELD_RESULTS_TRUNCATED, false);
+ jg.writeEndObject();
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void onQueryComplete(MSQTaskReportPayload report)
+ {
+ try {
+ openGenerator();
+ statusReport = report.getStatus();
+ writeObjectField(FIELD_STATUS, report.getStatus());
+
+ if (report.getStages() != null) {
+ writeObjectField(FIELD_STAGES, report.getStages());
+ }
+
+ if (report.getCounters() != null) {
+ writeObjectField(FIELD_COUNTERS, report.getCounters());
+ }
+
+ jg.writeEndObject(); // End MSQTaskReportPayload
+ jg.writeEndObject(); // End MSQTaskReport
+ jg.writeObjectField(TaskContextReport.REPORT_KEY, new TaskContextReport(taskId, taskContext));
+ jg.writeEndObject(); // End report
+ jg.close();
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
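The streaming pattern above can be seen in miniature with plain Jackson; field names are simplified, and `out` (an `OutputStream`) and `rows` are assumed:

```java
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.ObjectMapper;

// Rows are serialized one at a time into an already-open JSON array, so the
// full result set never needs to be held in memory.
ObjectMapper mapper = new ObjectMapper();
try (JsonGenerator gen = mapper.createGenerator(out)) {
  gen.writeStartObject();
  gen.writeArrayFieldStart("results");
  for (Object[] row : rows) {
    gen.writeObject(row);          // one row per call, as results arrive
  }
  gen.writeEndArray();
  gen.writeBooleanField("resultsTruncated", false);
  gen.writeEndObject();
}
```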
+ public MSQStatusReport getStatusReport()
+ {
+ if (statusReport == null) {
+ throw new ISE("Status report not available");
+ }
+
+ return statusReport;
+ }
+
+ /**
+ * Initialize {@link #jg}, if it wasn't already set up. Writes the object start marker, too.
+ */
+ private void openGenerator() throws IOException
+ {
+ if (jg == null) {
+ jg = jsonMapper.createGenerator(reportSink.get());
+ jg.writeStartObject(); // Start report
+ jg.writeObjectFieldStart(MSQTaskReport.REPORT_KEY); // Start MSQTaskReport
+ jg.writeStringField(FIELD_TYPE, MSQTaskReport.REPORT_KEY);
+ jg.writeStringField(FIELD_TASK_ID, taskId);
+ jg.writeObjectFieldStart(FIELD_PAYLOAD); // Start MSQTaskReportPayload
+ }
+ }
+
+ /**
+ * Write a field name followed by an object. Unlike {@link JsonGenerator#writeObjectField(String, Object)},
+ * this approach avoids the re-creation of a {@link SerializerProvider} for each call.
+ */
+ private void writeObjectField(final String fieldName, final Object value) throws IOException
+ {
+ jg.writeFieldName(fieldName);
+ JacksonUtils.writeObjectUsingSerializerProvider(jg, serializers, value);
+ }
+
+ public interface OutputStreamSupplier
+ {
+ OutputStream get() throws IOException;
+ }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/ControllerChatHandler.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/ControllerChatHandler.java
index 4be026ac34c2..bf3dd4a6bf14 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/ControllerChatHandler.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/ControllerChatHandler.java
@@ -19,179 +19,16 @@
 package org.apache.druid.msq.indexing.client;

-import org.apache.druid.indexer.report.TaskReport;
-import org.apache.druid.indexing.common.TaskToolbox;
-import org.apache.druid.msq.counters.CounterSnapshots;
-import org.apache.druid.msq.counters.CounterSnapshotsTree;
 import org.apache.druid.msq.exec.Controller;
-import org.apache.druid.msq.exec.ControllerClient;
-import org.apache.druid.msq.indexing.MSQControllerTask;
-import org.apache.druid.msq.indexing.MSQTaskList;
-import org.apache.druid.msq.indexing.error.MSQErrorReport;
-import org.apache.druid.msq.kernel.StageId;
-import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
+import org.apache.druid.msq.indexing.IndexerResourcePermissionMapper;
+import org.apache.druid.msq.rpc.ControllerResource;
 import org.apache.druid.segment.realtime.firehose.ChatHandler;
-import org.apache.druid.segment.realtime.firehose.ChatHandlers;
-import org.apache.druid.server.security.Action;
+import org.apache.druid.server.security.AuthorizerMapper;

-import javax.servlet.http.HttpServletRequest;
-import javax.ws.rs.Consumes;
-import javax.ws.rs.GET;
-import javax.ws.rs.POST;
-import javax.ws.rs.Path;
-import javax.ws.rs.PathParam;
-import javax.ws.rs.Produces;
-import javax.ws.rs.core.Context;
-import javax.ws.rs.core.MediaType;
-import javax.ws.rs.core.Response;
-import java.util.List;
-
-public class ControllerChatHandler implements ChatHandler
+public class ControllerChatHandler extends ControllerResource implements ChatHandler
 {
- private final Controller controller;
- private final MSQControllerTask task;
- private final TaskToolbox toolbox;
-
- public ControllerChatHandler(TaskToolbox toolbox, Controller controller)
- {
- this.controller = controller;
- this.task = controller.task();
- this.toolbox = toolbox;
- }
-
- /**
- * Used by subtasks to
post {@link PartialKeyStatisticsInformation} for shuffling stages. - * - * See {@link ControllerClient#postPartialKeyStatistics(StageId, int, PartialKeyStatisticsInformation)} - * for the client-side code that calls this API. - */ - @POST - @Path("/partialKeyStatisticsInformation/{queryId}/{stageNumber}/{workerNumber}") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpPostPartialKeyStatistics( - final Object partialKeyStatisticsObject, - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @PathParam("workerNumber") final int workerNumber, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - controller.updatePartialKeyStatisticsInformation(stageNumber, workerNumber, partialKeyStatisticsObject); - return Response.status(Response.Status.ACCEPTED).build(); - } - - /** - * Used by subtasks to post system errors. Note that the errors are organized by taskId, not by query/stage/worker, - * because system errors are associated with a task rather than a specific query/stage/worker execution context. - * - * See {@link ControllerClient#postWorkerError} for the client-side code that calls this API. - */ - @POST - @Path("/workerError/{taskId}") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpPostWorkerError( - final MSQErrorReport errorReport, - @PathParam("taskId") final String taskId, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - controller.workerError(errorReport); - return Response.status(Response.Status.ACCEPTED).build(); - } - - /** - * Used by subtasks to post system warnings. - * - * See {@link ControllerClient#postWorkerWarning} for the client-side code that calls this API. - */ - @POST - @Path("/workerWarning") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpPostWorkerWarning( - final List errorReport, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - controller.workerWarning(errorReport); - return Response.status(Response.Status.ACCEPTED).build(); - } - - - /** - * Used by subtasks to post {@link CounterSnapshots} periodically. - * - * See {@link ControllerClient#postCounters} for the client-side code that calls this API. - */ - @POST - @Path("/counters/{taskId}") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpPostCounters( - @PathParam("taskId") final String taskId, - final CounterSnapshotsTree snapshotsTree, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - controller.updateCounters(taskId, snapshotsTree); - return Response.status(Response.Status.OK).build(); - } - - /** - * Used by subtasks to post notifications that their results are ready. - * - * See {@link ControllerClient#postResultsComplete} for the client-side code that calls this API. 
- */ - @POST - @Path("/resultsComplete/{queryId}/{stageNumber}/{workerNumber}") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpPostResultsComplete( - final Object resultObject, - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @PathParam("workerNumber") final int workerNumber, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - controller.resultsComplete(queryId, stageNumber, workerNumber, resultObject); - return Response.status(Response.Status.ACCEPTED).build(); - } - - /** - * See {@link ControllerClient#getTaskList} for the client-side code that calls this API. - */ - @GET - @Path("/taskList") - @Produces(MediaType.APPLICATION_JSON) - public Response httpGetTaskList(@Context final HttpServletRequest req) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - - return Response.ok(new MSQTaskList(controller.getTaskIds())).build(); - } - - /** - * See {@link org.apache.druid.indexing.overlord.RemoteTaskRunner#streamTaskReports} for the client-side code that - * calls this API. - */ - @GET - @Path("/liveReports") - @Produces(MediaType.APPLICATION_JSON) - public Response httpGetLiveReports(@Context final HttpServletRequest req) + public ControllerChatHandler(Controller controller, String dataSource, AuthorizerMapper authorizerMapper) { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - final TaskReport.ReportMap reports = controller.liveReports(); - if (reports == null) { - return Response.status(Response.Status.NOT_FOUND).build(); - } - return Response.ok(reports).build(); + super(controller, new IndexerResourcePermissionMapper(dataSource), authorizerMapper); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java index 493cbeb62424..81303eb43848 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java @@ -80,6 +80,22 @@ public void postPartialKeyStatistics( ); } + @Override + public void postDoneReadingInput(StageId stageId, int workerNumber) throws IOException + { + final String path = StringUtils.format( + "/doneReadingInput/%s/%d/%d", + StringUtils.urlEncode(stageId.getQueryId()), + stageId.getStageNumber(), + workerNumber + ); + + doRequest( + new RequestBuilder(HttpMethod.POST, path), + IgnoreHttpResponseHandler.INSTANCE + ); + } + @Override public void postCounters(String workerId, CounterSnapshotsTree snapshotsTree) throws IOException { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerClient.java index af089a296006..e9b4a370b241 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerClient.java @@ -19,32 +19,11 @@ package 
org.apache.druid.msq.indexing.client; -import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.MoreExecutors; -import com.google.common.util.concurrent.SettableFuture; import com.google.errorprone.annotations.concurrent.GuardedBy; -import org.apache.druid.common.guava.FutureUtils; -import org.apache.druid.frame.channel.ReadableByteChunksFrameChannel; -import org.apache.druid.frame.file.FrameFileHttpResponseHandler; -import org.apache.druid.frame.file.FrameFilePartialFetch; -import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.java.util.common.Pair; -import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.java.util.common.concurrent.Execs; -import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.java.util.http.client.response.BytesFullResponseHandler; -import org.apache.druid.java.util.http.client.response.BytesFullResponseHolder; -import org.apache.druid.msq.counters.CounterSnapshotsTree; -import org.apache.druid.msq.exec.WorkerClient; -import org.apache.druid.msq.kernel.StageId; -import org.apache.druid.msq.kernel.WorkOrder; -import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; -import org.apache.druid.rpc.IgnoreHttpResponseHandler; -import org.apache.druid.rpc.RequestBuilder; +import org.apache.druid.msq.indexing.MSQWorkerTask; +import org.apache.druid.msq.rpc.BaseWorkerClientImpl; import org.apache.druid.rpc.ServiceClient; import org.apache.druid.rpc.ServiceClientFactory; import org.apache.druid.rpc.StandardRetryPolicy; @@ -52,10 +31,8 @@ import org.apache.druid.rpc.indexing.SpecificTaskRetryPolicy; import org.apache.druid.rpc.indexing.SpecificTaskServiceLocator; import org.apache.druid.utils.CloseableUtils; -import org.jboss.netty.handler.codec.http.HttpMethod; -import javax.annotation.Nonnull; -import javax.ws.rs.core.HttpHeaders; +import javax.ws.rs.core.MediaType; import java.io.Closeable; import java.io.IOException; import java.util.HashMap; @@ -63,11 +40,13 @@ import java.util.Map; import java.util.stream.Collectors; -public class IndexerWorkerClient implements WorkerClient +/** + * Worker client for {@link MSQWorkerTask}. 
+ */ +public class IndexerWorkerClient extends BaseWorkerClientImpl { private final ServiceClientFactory clientFactory; private final OverlordClient overlordClient; - private final ObjectMapper jsonMapper; @GuardedBy("clientMap") private final Map> clientMap = new HashMap<>(); @@ -78,202 +57,9 @@ public IndexerWorkerClient( final ObjectMapper jsonMapper ) { + super(jsonMapper, MediaType.APPLICATION_JSON); this.clientFactory = clientFactory; this.overlordClient = overlordClient; - this.jsonMapper = jsonMapper; - } - - - @Nonnull - public static String getStagePartitionPath(StageId stageId, int partitionNumber) - { - return StringUtils.format( - "/channels/%s/%d/%d", - StringUtils.urlEncode(stageId.getQueryId()), - stageId.getStageNumber(), - partitionNumber - ); - } - - @Override - public ListenableFuture postWorkOrder(String workerTaskId, WorkOrder workOrder) - { - return getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, "/workOrder") - .jsonContent(jsonMapper, workOrder), - IgnoreHttpResponseHandler.INSTANCE - ); - } - - @Override - public ListenableFuture fetchClusterByStatisticsSnapshot( - String workerTaskId, - String queryId, - int stageNumber - ) - { - String path = StringUtils.format( - "/keyStatistics/%s/%d", - StringUtils.urlEncode(queryId), - stageNumber - ); - - return FutureUtils.transform( - getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, path), - new BytesFullResponseHandler() - ), - holder -> deserialize(holder, new TypeReference() - { - }) - ); - } - - @Override - public ListenableFuture fetchClusterByStatisticsSnapshotForTimeChunk( - String workerTaskId, - String queryId, - int stageNumber, - long timeChunk - ) - { - String path = StringUtils.format( - "/keyStatisticsForTimeChunk/%s/%d/%d", - StringUtils.urlEncode(queryId), - stageNumber, - timeChunk - ); - - return FutureUtils.transform( - getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, path), - new BytesFullResponseHandler() - ), - holder -> deserialize(holder, new TypeReference() - { - }) - ); - } - - @Override - public ListenableFuture postResultPartitionBoundaries( - String workerTaskId, - StageId stageId, - ClusterByPartitions partitionBoundaries - ) - { - final String path = StringUtils.format( - "/resultPartitionBoundaries/%s/%d", - StringUtils.urlEncode(stageId.getQueryId()), - stageId.getStageNumber() - ); - - return getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, path) - .jsonContent(jsonMapper, partitionBoundaries), - IgnoreHttpResponseHandler.INSTANCE - ); - } - - /** - * Client-side method for {@link WorkerChatHandler#httpPostCleanupStage}. 
- */ - @Override - public ListenableFuture postCleanupStage( - final String workerTaskId, - final StageId stageId - ) - { - final String path = StringUtils.format( - "/cleanupStage/%s/%d", - StringUtils.urlEncode(stageId.getQueryId()), - stageId.getStageNumber() - ); - - return getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, path), - IgnoreHttpResponseHandler.INSTANCE - ); - } - - @Override - public ListenableFuture postFinish(String workerTaskId) - { - return getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, "/finish"), - IgnoreHttpResponseHandler.INSTANCE - ); - } - - @Override - public ListenableFuture getCounters(String workerTaskId) - { - return FutureUtils.transform( - getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.GET, "/counters"), - new BytesFullResponseHandler() - ), - holder -> deserialize(holder, new TypeReference() - { - }) - ); - } - - private static final Logger log = new Logger(IndexerWorkerClient.class); - - @Override - public ListenableFuture fetchChannelData( - String workerTaskId, - StageId stageId, - int partitionNumber, - long offset, - ReadableByteChunksFrameChannel channel - ) - { - final ServiceClient client = getClient(workerTaskId); - final String path = getStagePartitionPath(stageId, partitionNumber); - - final SettableFuture retVal = SettableFuture.create(); - final ListenableFuture clientFuture = - client.asyncRequest( - new RequestBuilder(HttpMethod.GET, StringUtils.format("%s?offset=%d", path, offset)) - .header(HttpHeaders.ACCEPT_ENCODING, "identity"), // Data is compressed at app level - new FrameFileHttpResponseHandler(channel) - ); - - Futures.addCallback( - clientFuture, - new FutureCallback() - { - @Override - public void onSuccess(FrameFilePartialFetch partialFetch) - { - if (partialFetch.isExceptionCaught()) { - // Exception while reading channel. Recoverable. - log.noStackTrace().info( - partialFetch.getExceptionCaught(), - "Encountered exception while reading channel [%s]", - channel.getId() - ); - } - - // Empty fetch means this is the last fetch for the channel. - partialFetch.backpressureFuture().addListener( - () -> retVal.set(partialFetch.isLastFetch()), - Execs.directExecutor() - ); - } - - @Override - public void onFailure(Throwable t) - { - retVal.setException(t); - } - }, - MoreExecutors.directExecutor() - ); - - return retVal; } @Override @@ -291,36 +77,22 @@ public void close() throws IOException } } - private ServiceClient getClient(final String workerTaskId) + @Override + protected ServiceClient getClient(final String workerId) { synchronized (clientMap) { return clientMap.computeIfAbsent( - workerTaskId, + workerId, id -> { final SpecificTaskServiceLocator locator = new SpecificTaskServiceLocator(id, overlordClient); final ServiceClient client = clientFactory.makeClient( id, locator, - new SpecificTaskRetryPolicy(workerTaskId, StandardRetryPolicy.unlimitedWithoutRetryLogging()) + new SpecificTaskRetryPolicy(workerId, StandardRetryPolicy.unlimitedWithoutRetryLogging()) ); return Pair.of(client, locator); } ).lhs; } } - - /** - * Deserialize a {@link BytesFullResponseHolder} as JSON. - *

- * It would be reasonable to move this to {@link BytesFullResponseHolder} itself, or some shared utility class. - */ - private T deserialize(final BytesFullResponseHolder bytesHolder, final TypeReference typeReference) - { - try { - return jsonMapper.readValue(bytesHolder.getContent(), typeReference); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DataSourceMSQDestination.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DataSourceMSQDestination.java index 0854582a733c..ea3072bfe45a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DataSourceMSQDestination.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DataSourceMSQDestination.java @@ -133,6 +133,18 @@ public boolean isReplaceTimeChunks() return replaceTimeChunks != null; } + @Override + public long getRowsInTaskReport() + { + return 0; + } + + @Override + public MSQSelectDestination toSelectDestination() + { + return null; + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DurableStorageMSQDestination.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DurableStorageMSQDestination.java index e522243b60d2..88fe5f58e5a1 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DurableStorageMSQDestination.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DurableStorageMSQDestination.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.indexing.destination; import com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.druid.msq.exec.Limits; import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.querykit.ShuffleSpecFactories; import org.apache.druid.msq.querykit.ShuffleSpecFactory; @@ -63,4 +64,16 @@ public Optional getDestinationResource() { return Optional.of(new Resource(MSQControllerTask.DUMMY_DATASOURCE_FOR_SELECT, ResourceType.DATASOURCE)); } + + @Override + public long getRowsInTaskReport() + { + return Limits.MAX_SELECT_RESULT_ROWS; + } + + @Override + public MSQSelectDestination toSelectDestination() + { + return MSQSelectDestination.DURABLESTORAGE; + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/ExportMSQDestination.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/ExportMSQDestination.java index 14ac0ce4c2e8..d6a78def63a8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/ExportMSQDestination.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/ExportMSQDestination.java @@ -64,6 +64,18 @@ public ResultFormat getResultFormat() return resultFormat; } + @Override + public long getRowsInTaskReport() + { + return 0; + } + + @Override + public MSQSelectDestination toSelectDestination() + { + return MSQSelectDestination.EXPORT; + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQDestination.java 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQDestination.java
index 39460b15194c..ad7878f049a1 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQDestination.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQDestination.java
@@ -21,9 +21,11 @@
 import com.fasterxml.jackson.annotation.JsonSubTypes;
 import com.fasterxml.jackson.annotation.JsonTypeInfo;
+import org.apache.druid.msq.indexing.TaskReportQueryListener;
 import org.apache.druid.msq.querykit.ShuffleSpecFactory;
 import org.apache.druid.server.security.Resource;

+import javax.annotation.Nullable;
 import java.util.Optional;

@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@@ -35,7 +37,31 @@
 })
public interface MSQDestination
{
+ /**
+ * Returned by {@link #getRowsInTaskReport()} when an unlimited number of rows should be included in the task report.
+ */
+ long UNLIMITED = -1;
+
+ /**
+ * Shuffle spec for the final stage, which writes results to the destination.
+ */
 ShuffleSpecFactory getShuffleSpecFactory(int targetSize);

+ /**
+ * Return the resource for this destination. Used for security checks.
+ */
 Optional<Resource> getDestinationResource();
+
+ /**
+ * Number of rows to include in the task report when using {@link TaskReportQueryListener}. Zero means do not
+ * include results in the report at all. {@link #UNLIMITED} means include an unlimited number of rows.
+ */
+ long getRowsInTaskReport();
+
+ /**
+ * Return the {@link MSQSelectDestination} that corresponds to this destination. Returns null if this is not a
+ * SELECT destination (for example, returns null for {@link DataSourceMSQDestination}).
+ */
+ @Nullable
+ MSQSelectDestination toSelectDestination();
}
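How a listener interprets this contract, in a small sketch that mirrors `readResults()` in `TaskReportQueryListener` above:

```java
// 0         -> no results in the report (ingest, export destinations)
// UNLIMITED -> every row is written (taskReport destination)
// n > 0     -> up to n rows, after which resultsTruncated is set to true
long rows = destination.getRowsInTaskReport();
boolean includeResults = rows == MSQDestination.UNLIMITED || rows > 0;
```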
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQSelectDestination.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQSelectDestination.java
index e32705462470..0d21bdbe0c2f 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQSelectDestination.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQSelectDestination.java
@@ -22,35 +22,31 @@
 import com.fasterxml.jackson.annotation.JsonValue;

 /**
- * Determines the destination for results of select queries.
+ * Determines the destination for results of select queries. Convertible to and from {@link MSQDestination}, though
+ * with fewer options. Provided directly by end users in the query context.
 */
public enum MSQSelectDestination
{
 /**
 * Writes all the results directly to the report.
 */
- TASKREPORT("taskReport", false),
+ TASKREPORT("taskReport"),
+
 /**
- * Writes all the results as files in a specified format to an external location outside druid.
+ * Writes all the results as files in a specified format to an external location outside Druid.
 */
- EXPORT("export", false),
+ EXPORT("export"),
+
 /**
 * Writes the results as frame files to durable storage. Task report can be truncated to a preview.
 */
- DURABLESTORAGE("durableStorage", true);
+ DURABLESTORAGE("durableStorage");

 private final String name;
- private final boolean shouldTruncateResultsInTaskReport;
-
- public boolean shouldTruncateResultsInTaskReport()
- {
- return shouldTruncateResultsInTaskReport;
- }

- MSQSelectDestination(String name, boolean shouldTruncateResultsInTaskReport)
+ MSQSelectDestination(String name)
 {
 this.name = name;
- this.shouldTruncateResultsInTaskReport = shouldTruncateResultsInTaskReport;
 }

 @JsonValue
@@ -58,13 +54,4 @@ public String getName()
 {
 return name;
 }
-
- @Override
- public String toString()
- {
- return "MSQSelectDestination{" +
- "name='" + name + '\'' +
- ", shouldTruncateResultsInTaskReport=" + shouldTruncateResultsInTaskReport +
- '}';
- }
}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/TaskReportMSQDestination.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/TaskReportMSQDestination.java
index 3f199255ac76..dadc40048b66 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/TaskReportMSQDestination.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/TaskReportMSQDestination.java
@@ -61,4 +61,16 @@ public Optional<Resource> getDestinationResource()
 {
 return Optional.of(new Resource(MSQControllerTask.DUMMY_DATASOURCE_FOR_SELECT, ResourceType.DATASOURCE));
 }
+
+ @Override
+ public long getRowsInTaskReport()
+ {
+ return UNLIMITED;
+ }
+
+ @Override
+ public MSQSelectDestination toSelectDestination()
+ {
+ return MSQSelectDestination.TASKREPORT;
+ }
}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java
index ed30179306ad..5c80f065eef3 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java
@@ -47,8 +47,9 @@ public NotEnoughMemoryFault(
 {
 super(
 CODE,
- "Not enough memory. Required at least %,d bytes. (total = %,d bytes; usable = %,d bytes; server workers = %,d; server threads = %,d). Increase JVM memory with the -xmx option"
- + (serverWorkers > 1 ? " or reduce number of server workers" : ""),
+ "Not enough memory. Required at least %,d bytes. (total = %,d bytes; usable = %,d bytes; "
+ + "worker capacity = %,d; processing threads = %,d). Increase JVM memory with the -Xmx option"
+ + (serverWorkers > 1 ?
" or reduce worker capacity on this server" : ""), suggestedServerMemory, serverMemory, usableMemory, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java index b96ce469145e..0479b2959554 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java @@ -25,23 +25,14 @@ import com.google.common.base.Preconditions; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.common.config.Configs; -import org.apache.druid.java.util.common.guava.Sequences; -import org.apache.druid.java.util.common.guava.Yielder; -import org.apache.druid.java.util.common.guava.Yielders; -import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.msq.exec.Limits; -import org.apache.druid.msq.indexing.destination.MSQSelectDestination; -import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.segment.column.ColumnType; import javax.annotation.Nullable; -import java.util.ArrayList; import java.util.List; import java.util.Objects; public class MSQResultsReport { - private static final Logger log = new Logger(MSQResultsReport.class); /** * Like {@link org.apache.druid.segment.column.RowSignature}, but allows duplicate column names for compatibility * with SQL (which also allows duplicate column names in query results). @@ -49,72 +40,21 @@ public class MSQResultsReport private final List signature; @Nullable private final List sqlTypeNames; - private final Yielder resultYielder; + private final List results; private final boolean resultsTruncated; - public MSQResultsReport( - final List signature, - @Nullable final List sqlTypeNames, - final Yielder resultYielder, - @Nullable Boolean resultsTruncated - ) - { - this.signature = Preconditions.checkNotNull(signature, "signature"); - this.sqlTypeNames = sqlTypeNames; - this.resultYielder = Preconditions.checkNotNull(resultYielder, "resultYielder"); - this.resultsTruncated = Configs.valueOrDefault(resultsTruncated, false); - } - - /** - * Method that enables Jackson deserialization. - */ @JsonCreator - static MSQResultsReport fromJson( + public MSQResultsReport( @JsonProperty("signature") final List signature, @JsonProperty("sqlTypeNames") @Nullable final List sqlTypeNames, @JsonProperty("results") final List results, @JsonProperty("resultsTruncated") final Boolean resultsTruncated ) { - return new MSQResultsReport(signature, sqlTypeNames, Yielders.each(Sequences.simple(results)), resultsTruncated); - } - - public static MSQResultsReport createReportAndLimitRowsIfNeeded( - final List signature, - @Nullable final List sqlTypeNames, - Yielder resultYielder, - MSQSelectDestination selectDestination - ) - { - List results = new ArrayList<>(); - long rowCount = 0; - int factor = 1; - while (!resultYielder.isDone()) { - results.add(resultYielder.get()); - resultYielder = resultYielder.next(null); - ++rowCount; - if (selectDestination.shouldTruncateResultsInTaskReport() && rowCount >= Limits.MAX_SELECT_RESULT_ROWS) { - break; - } - if (rowCount % (factor * Limits.MAX_SELECT_RESULT_ROWS) == 0) { - log.warn( - "Task report is getting too large with %d rows. Large task reports can cause the controller to go out of memory. 
" - + "Consider using the 'limit %d' clause in your query to reduce the number of rows in the result. " - + "If you require all the results, consider setting [%s=%s] in the query context which will allow you to fetch large result sets.", - rowCount, - Limits.MAX_SELECT_RESULT_ROWS, - MultiStageQueryContext.CTX_SELECT_DESTINATION, - MSQSelectDestination.DURABLESTORAGE.getName() - ); - factor = factor < 32 ? factor * 2 : 32; - } - } - return new MSQResultsReport( - signature, - sqlTypeNames, - Yielders.each(Sequences.simple(results)), - !resultYielder.isDone() - ); + this.signature = Preconditions.checkNotNull(signature, "signature"); + this.sqlTypeNames = sqlTypeNames; + this.results = Preconditions.checkNotNull(results, "results"); + this.resultsTruncated = Configs.valueOrDefault(resultsTruncated, false); } @JsonProperty("signature") @@ -132,9 +72,9 @@ public List getSqlTypeNames() } @JsonProperty("results") - public Yielder getResultYielder() + public List getResults() { - return resultYielder; + return results; } @JsonProperty("resultsTruncated") diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStagesReport.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStagesReport.java index 422d8235fe20..76a077a3ebe0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStagesReport.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStagesReport.java @@ -24,7 +24,9 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.ShuffleKind; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.controller.ControllerStagePhase; import org.joda.time.DateTime; @@ -52,7 +54,8 @@ public static MSQStagesReport create( final Map stagePhaseMap, final Map stageRuntimeMap, final Map stageWorkerCountMap, - final Map stagePartitionCountMap + final Map stagePartitionCountMap, + final Map stageOutputChannelModeMap ) { final List stages = new ArrayList<>(); @@ -76,6 +79,8 @@ public static MSQStagesReport create( stagePhaseMap.get(stageNumber), workerCount, partitionCount, + stageDef.doesShuffle() ? 
stageDef.getShuffleSpec().kind() : null, + stageOutputChannelModeMap.get(stageNumber), stageStartTime, stageDuration ); @@ -126,6 +131,8 @@ public static class Stage private final ControllerStagePhase phase; private final int workerCount; private final int partitionCount; + private final ShuffleKind shuffleKind; + private final OutputChannelMode outputChannelMode; private final DateTime startTime; private final long duration; @@ -136,7 +143,9 @@ private Stage( @JsonProperty("phase") @Nullable final ControllerStagePhase phase, @JsonProperty("workerCount") final int workerCount, @JsonProperty("partitionCount") final int partitionCount, - @JsonProperty("startTime") @Nullable final DateTime startTime, + @JsonProperty("shuffle") final ShuffleKind shuffleKind, + @JsonProperty("output") final OutputChannelMode outputChannelMode, + @JsonProperty("startTime") @Nullable final DateTime startTime, @JsonProperty("duration") final long duration ) { @@ -145,6 +154,8 @@ private Stage( this.phase = phase; this.workerCount = workerCount; this.partitionCount = partitionCount; + this.shuffleKind = shuffleKind; + this.outputChannelMode = outputChannelMode; this.startTime = startTime; this.duration = duration; } @@ -184,6 +195,20 @@ public int getPartitionCount() return partitionCount; } + @JsonProperty("shuffle") + @JsonInclude(JsonInclude.Include.NON_NULL) + public ShuffleKind getShuffleKind() + { + return shuffleKind; + } + + @JsonProperty("output") + @JsonInclude(JsonInclude.Include.NON_NULL) + public OutputChannelMode getOutputChannelMode() + { + return outputChannelMode; + } + @JsonProperty("sort") @JsonInclude(JsonInclude.Include.NON_DEFAULT) public boolean isSorting() diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java index eca8998f865c..8bab9e0832bd 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java @@ -24,9 +24,11 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import org.apache.druid.indexer.TaskState; +import org.apache.druid.indexer.TaskStatus; import org.apache.druid.msq.exec.SegmentLoadStatusFetcher; -import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; +import org.apache.druid.msq.exec.WorkerStats; import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.indexing.error.MSQFaultUtils; import org.joda.time.DateTime; import javax.annotation.Nullable; @@ -50,7 +52,7 @@ public class MSQStatusReport private final long durationMs; - private final Map> workerStats; + private final Map> workerStats; private final int pendingTasks; @@ -69,10 +71,11 @@ public MSQStatusReport( @JsonProperty("warnings") Collection warningReports, @JsonProperty("startTime") @Nullable DateTime startTime, @JsonProperty("durationMs") long durationMs, - @JsonProperty("workers") Map> workerStats, + @JsonProperty("workers") Map> workerStats, @JsonProperty("pendingTasks") int pendingTasks, @JsonProperty("runningTasks") int runningTasks, - @JsonProperty("segmentLoadWaiterStatus") @Nullable SegmentLoadStatusFetcher.SegmentLoadWaiterStatus segmentLoadWaiterStatus, + @JsonProperty("segmentLoadWaiterStatus") @Nullable + SegmentLoadStatusFetcher.SegmentLoadWaiterStatus segmentLoadWaiterStatus, 
@JsonProperty("segmentReport") @Nullable MSQSegmentReport segmentReport ) { @@ -136,7 +139,7 @@ public long getDurationMs() } @JsonProperty("workers") - public Map> getWorkerStats() + public Map> getWorkerStats() { return workerStats; } @@ -157,6 +160,22 @@ public MSQSegmentReport getSegmentReport() return segmentReport; } + /** + * Returns a {@link TaskStatus} appropriate for this status report. + */ + public TaskStatus toTaskStatus(final String taskId) + { + if (status == TaskState.SUCCESS) { + return TaskStatus.success(taskId); + } else { + // Error report is nonnull when status code != SUCCESS. Use that message. + return TaskStatus.failure( + taskId, + MSQFaultUtils.generateMessageWithErrorCode(errorReport.getFault()) + ); + } + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQTaskReportPayload.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQTaskReportPayload.java index 111cb5aa83a3..bf00c9434df2 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQTaskReportPayload.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQTaskReportPayload.java @@ -28,6 +28,11 @@ public class MSQTaskReportPayload { + public static final String FIELD_STATUS = "status"; + public static final String FIELD_STAGES = "stages"; + public static final String FIELD_COUNTERS = "counters"; + public static final String FIELD_RESULTS = "results"; + private final MSQStatusReport status; @Nullable @@ -41,10 +46,10 @@ public class MSQTaskReportPayload @JsonCreator public MSQTaskReportPayload( - @JsonProperty("status") MSQStatusReport status, - @JsonProperty("stages") @Nullable MSQStagesReport stages, - @JsonProperty("counters") @Nullable CounterSnapshotsTree counters, - @JsonProperty("results") @Nullable MSQResultsReport results + @JsonProperty(FIELD_STATUS) MSQStatusReport status, + @JsonProperty(FIELD_STAGES) @Nullable MSQStagesReport stages, + @JsonProperty(FIELD_COUNTERS) @Nullable CounterSnapshotsTree counters, + @JsonProperty(FIELD_RESULTS) @Nullable MSQResultsReport results ) { this.status = status; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicer.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicer.java index 074d1a1c0489..ff1808e463c8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicer.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicer.java @@ -22,7 +22,8 @@ import java.util.List; /** - * Slices {@link InputSpec} into {@link InputSlice} on the controller. + * Slices {@link InputSpec} into {@link InputSlice} on the controller. Each slice is assigned to a single worker, and + * the slice number equals the worker number. 
*/ public interface InputSpecSlicer { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicerFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicerFactory.java index 24b5cc1c5259..8accf1ec1e90 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicerFactory.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicerFactory.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.input; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.stage.ReadablePartitions; import org.apache.druid.msq.input.stage.StageInputSpecSlicer; @@ -32,5 +33,8 @@ */ public interface InputSpecSlicerFactory { - InputSpecSlicer makeSlicer(Int2ObjectMap stagePartitionsMap); + InputSpecSlicer makeSlicer( + Int2ObjectMap stagePartitionsMap, + Int2ObjectMap stageOutputChannelModeMap + ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSlice.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSlice.java index eaf47a5df0dd..2c32d0e9ec0e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSlice.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSlice.java @@ -23,8 +23,10 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.InputSlice; +import javax.annotation.Nullable; import java.util.Objects; /** @@ -38,14 +40,19 @@ public class StageInputSlice implements InputSlice private final int stage; private final ReadablePartitions partitions; + @Nullable // May be null when created by older controllers + private final OutputChannelMode outputChannelMode; + @JsonCreator public StageInputSlice( @JsonProperty("stage") int stageNumber, - @JsonProperty("partitions") ReadablePartitions partitions + @JsonProperty("partitions") ReadablePartitions partitions, + @JsonProperty("output") OutputChannelMode outputChannelMode ) { this.stage = stageNumber; this.partitions = Preconditions.checkNotNull(partitions, "partitions"); + this.outputChannelMode = outputChannelMode; } @JsonProperty("stage") @@ -60,6 +67,13 @@ public ReadablePartitions getPartitions() return partitions; } + @JsonProperty("output") + @Nullable // May be null when created by older controllers + public OutputChannelMode getOutputChannelMode() + { + return outputChannelMode; + } + @Override public int fileCount() { @@ -76,21 +90,24 @@ public boolean equals(Object o) return false; } StageInputSlice that = (StageInputSlice) o; - return stage == that.stage && Objects.equals(partitions, that.partitions); + return stage == that.stage + && Objects.equals(partitions, that.partitions) + && outputChannelMode == that.outputChannelMode; } @Override public int hashCode() { - return Objects.hash(stage, partitions); + return Objects.hash(stage, partitions, outputChannelMode); } @Override public String toString() { - return "StageInputSpec{" + + return "StageInputSlice{" + "stage=" + stage + ", partitions=" + partitions + + ", outputChannelMode=" + outputChannelMode + '}'; } } diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSpecSlicer.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSpecSlicer.java index ad41b4234e85..f3b5d23ae4f0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSpecSlicer.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSpecSlicer.java @@ -21,6 +21,7 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.msq.input.InputSpec; import org.apache.druid.msq.input.InputSpecSlicer; @@ -36,9 +37,16 @@ public class StageInputSpecSlicer implements InputSpecSlicer // Stage number -> partitions for that stage private final Int2ObjectMap stagePartitionsMap; - public StageInputSpecSlicer(final Int2ObjectMap stagePartitionsMap) + // Stage number -> output mode for that stage + private final Int2ObjectMap stageOutputChannelModeMap; + + public StageInputSpecSlicer( + final Int2ObjectMap stagePartitionsMap, + final Int2ObjectMap stageOutputChannelModeMap + ) { this.stagePartitionsMap = stagePartitionsMap; + this.stageOutputChannelModeMap = stageOutputChannelModeMap; } @Override @@ -53,9 +61,14 @@ public List sliceStatic(InputSpec inputSpec, int maxNumSlices) final StageInputSpec stageInputSpec = (StageInputSpec) inputSpec; final ReadablePartitions stagePartitions = stagePartitionsMap.get(stageInputSpec.getStageNumber()); + final OutputChannelMode outputChannelMode = stageOutputChannelModeMap.get(stageInputSpec.getStageNumber()); if (stagePartitions == null) { - throw new ISE("Stage [%d] not available", stageInputSpec.getStageNumber()); + throw new ISE("Stage[%d] output partitions not available", stageInputSpec.getStageNumber()); + } + + if (outputChannelMode == null) { + throw new ISE("Stage[%d] output mode not available", stageInputSpec.getStageNumber()); } // Decide how many workers to use, and assign inputs. 
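StageInputSpecSlicer can no longer be constructed from result partitions alone: the controller must also hand it the per-stage OutputChannelMode, and (per the hunk below) sliceStatic fails fast when either map is missing an entry for the input stage. A minimal sketch of satisfying the new two-map makeSlicer signature, assuming InputSpecSlicerFactory stays a single-abstract-method interface; the harness class is hypothetical and only the constructor shape comes from this patch:

    import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
    import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
    import org.apache.druid.msq.exec.OutputChannelMode;
    import org.apache.druid.msq.input.InputSpecSlicerFactory;
    import org.apache.druid.msq.input.stage.ReadablePartitions;
    import org.apache.druid.msq.input.stage.StageInputSpecSlicer;

    public class SlicerFactoryExample
    {
      public static void main(String[] args)
      {
        // Both maps are keyed by stage number, mirroring the slicer's fields.
        final Int2ObjectMap<ReadablePartitions> stagePartitionsMap = new Int2ObjectAVLTreeMap<>();
        final Int2ObjectMap<OutputChannelMode> stageOutputChannelModeMap = new Int2ObjectAVLTreeMap<>();

        // A factory that only handles stage inputs; production code would also
        // dispatch to slicers for table and external inputs.
        final InputSpecSlicerFactory factory = StageInputSpecSlicer::new;

        factory.makeSlicer(stagePartitionsMap, stageOutputChannelModeMap);
      }
    }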
@@ -63,7 +76,13 @@ public List sliceStatic(InputSpec inputSpec, int maxNumSlices) final List retVal = new ArrayList<>(); for (final ReadablePartitions partitions : workerPartitions) { - retVal.add(new StageInputSlice(stageInputSpec.getStageNumber(), partitions)); + retVal.add( + new StageInputSlice( + stageInputSpec.getStageNumber(), + partitions, + outputChannelMode + ) + ); } return retVal; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java index 7e93324ce68d..916dd3c1db38 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java @@ -22,19 +22,31 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterators; +import org.apache.druid.client.ImmutableSegmentLoadInfo; +import org.apache.druid.client.coordinator.CoordinatorClient; +import org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction; +import org.apache.druid.indexing.common.actions.TaskActionClient; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.exec.SegmentSource; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.UnknownFault; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.msq.input.InputSpec; import org.apache.druid.msq.input.InputSpecSlicer; import org.apache.druid.msq.input.NilInputSlice; import org.apache.druid.msq.input.SlicerUtils; -import org.apache.druid.msq.querykit.DataSegmentTimelineView; import org.apache.druid.query.filter.DimFilterUtils; import org.apache.druid.server.coordination.DruidServerMetadata; import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.SegmentTimeline; import org.apache.druid.timeline.TimelineLookup; +import org.apache.druid.timeline.VersionedIntervalTimeline; import org.joda.time.Interval; +import javax.annotation.Nullable; +import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -46,15 +58,25 @@ import java.util.stream.StreamSupport; /** - * Slices {@link TableInputSpec} into {@link SegmentsInputSlice}. + * Slices {@link TableInputSpec} into {@link SegmentsInputSlice} in tasks. 
*/ public class TableInputSpecSlicer implements InputSpecSlicer { - private final DataSegmentTimelineView timelineView; + private static final Logger log = new Logger(TableInputSpecSlicer.class); - public TableInputSpecSlicer(DataSegmentTimelineView timelineView) + private final CoordinatorClient coordinatorClient; + private final TaskActionClient taskActionClient; + private final SegmentSource includeSegmentSource; + + public TableInputSpecSlicer( + CoordinatorClient coordinatorClient, + TaskActionClient taskActionClient, + SegmentSource includeSegmentSource + ) { - this.timelineView = timelineView; + this.coordinatorClient = coordinatorClient; + this.taskActionClient = taskActionClient; + this.includeSegmentSource = includeSegmentSource; } @Override @@ -128,7 +150,7 @@ public List sliceDynamic( private Set getPrunedSegmentSet(final TableInputSpec tableInputSpec) { final TimelineLookup timeline = - timelineView.getTimeline(tableInputSpec.getDataSource(), tableInputSpec.getIntervals()).orElse(null); + getTimeline(tableInputSpec.getDataSource(), tableInputSpec.getIntervals()); if (timeline == null) { return Collections.emptySet(); @@ -159,6 +181,87 @@ private Set getPrunedSegmentSet(final TableInputSpec ta } } + @Nullable + private VersionedIntervalTimeline getTimeline( + final String dataSource, + final List intervals + ) + { + final boolean includeRealtime = SegmentSource.shouldQueryRealtimeServers(includeSegmentSource); + final Iterable realtimeAndHistoricalSegments; + + // Fetch the realtime segments and segments loaded on the historical. Do this first so that we don't miss any + // segment if they get handed off between the two calls. Segments loaded on historicals are deduplicated below, + // since we are only interested in realtime segments for now. + if (includeRealtime) { + realtimeAndHistoricalSegments = coordinatorClient.fetchServerViewSegments(dataSource, intervals); + } else { + realtimeAndHistoricalSegments = ImmutableList.of(); + } + + // Fetch all published, used segments (all non-realtime segments) from the metadata store. + // If the task is operating with a REPLACE lock, + // any segment created after the lock was acquired for its interval will not be considered. + final Collection publishedUsedSegments; + try { + // Additional check as the task action does not accept empty intervals + if (intervals.isEmpty()) { + publishedUsedSegments = Collections.emptySet(); + } else { + publishedUsedSegments = + taskActionClient.submit(new RetrieveUsedSegmentsAction(dataSource, intervals)); + } + } + catch (IOException e) { + throw new MSQException(e, UnknownFault.forException(e)); + } + + int realtimeCount = 0; + + // Deduplicate segments, giving preference to published used segments. + // We do this so that if any segments have been handed off in between the two metadata calls above, + // we directly fetch it from deep storage. + Set unifiedSegmentView = new HashSet<>(publishedUsedSegments); + + // Iterate over the realtime segments and segments loaded on the historical + for (ImmutableSegmentLoadInfo segmentLoadInfo : realtimeAndHistoricalSegments) { + Set servers = segmentLoadInfo.getServers(); + // Filter out only realtime servers. We don't want to query historicals for now, but we can in the future. + // This check can be modified then. 
+ Set realtimeServerMetadata + = servers.stream() + .filter(druidServerMetadata -> includeSegmentSource.getUsedServerTypes() + .contains(druidServerMetadata.getType()) + ) + .collect(Collectors.toSet()); + if (!realtimeServerMetadata.isEmpty()) { + realtimeCount += 1; + DataSegmentWithLocation dataSegmentWithLocation = new DataSegmentWithLocation( + segmentLoadInfo.getSegment(), + realtimeServerMetadata + ); + unifiedSegmentView.add(dataSegmentWithLocation); + } else { + // We don't have any segments of the required segment source, ignore the segment + } + } + + if (includeRealtime) { + log.info( + "Fetched total [%d] segments from coordinator: [%d] from metadata store, [%d] from server view", + unifiedSegmentView.size(), + publishedUsedSegments.size(), + realtimeCount + ); + } + + if (unifiedSegmentView.isEmpty()) { + return null; + } else { + return SegmentTimeline.forSegments(unifiedSegmentView); + } + } + private static List makeSlices( final TableInputSpec tableInputSpec, final List> assignments @@ -206,7 +309,8 @@ private static List createWeightedSegmentSet(List new HashSet<>()); serverVsSegmentsMap.get(druidServerMetadata).add(dataSegmentWithInterval); @@ -286,7 +390,8 @@ public DataServerRequestDescriptor toDataServerRequestDescriptor() { return new DataServerRequestDescriptor( serverMetadata, - segments.stream().map(DataSegmentWithInterval::toRichSegmentDescriptor).collect(Collectors.toList())); + segments.stream().map(DataSegmentWithInterval::toRichSegmentDescriptor).collect(Collectors.toList()) + ); } } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java index e773fcb87a97..8be2108a57a4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java @@ -119,7 +119,11 @@ public ClusterBy clusterBy() @Override public int partitionCount() { - throw new ISE("Number of partitions not known for [%s].", kind()); + if (maxPartitions == 1) { + return 1; + } else { + throw new ISE("Number of partitions not known for [%s] with maxPartitions[%d].", kind(), maxPartitions); + } } @JsonProperty("partitions") diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortShuffleSpec.java index 4d608c4fbe43..7f0878da0393 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortShuffleSpec.java @@ -35,6 +35,12 @@ public interface GlobalSortShuffleSpec extends ShuffleSpec */ boolean mustGatherResultKeyStatistics(); + /** + * Whether the {@link ClusterByStatisticsCollector} for this stage collects keys in aggregating mode or + * non-aggregating mode. + */ + boolean doesAggregate(); + /** * Generates a set of partitions based on the provided statistics. 
* diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java index fc453d76635b..fbc39fc672c3 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java @@ -59,16 +59,11 @@ public ClusterBy clusterBy() return clusterBy; } - @Override - public boolean doesAggregate() - { - return false; - } - @Override @JsonProperty("partitions") public int partitionCount() { return numPartitions; } + } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MixShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MixShuffleSpec.java index 6fbe16b6740f..b29d41c336ae 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MixShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MixShuffleSpec.java @@ -53,12 +53,6 @@ public ClusterBy clusterBy() return ClusterBy.none(); } - @Override - public boolean doesAggregate() - { - return false; - } - @Override public int partitionCount() { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinition.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinition.java index 553e119131d5..64f27f2fddc0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinition.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinition.java @@ -96,14 +96,14 @@ static QueryDefinition create(@JsonProperty("stages") final List stageBuilders = new ArrayList<>(); /** - * Package-private: callers should use {@link QueryDefinition#builder()}. + * Package-private: callers should use {@link QueryDefinition#builder(String)}. */ - QueryDefinitionBuilder() + QueryDefinitionBuilder(final String queryId) { - } - - public QueryDefinitionBuilder queryId(final String queryId) - { - this.queryId = Preconditions.checkNotNull(queryId, "queryId"); - return this; + this.queryId = queryId; } public QueryDefinitionBuilder add(final StageDefinitionBuilder stageBuilder) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleKind.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleKind.java index ac3bb99273e7..b1ae27e87166 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleKind.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleKind.java @@ -19,19 +19,23 @@ package org.apache.druid.msq.kernel; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; +import org.apache.druid.java.util.common.IAE; + public enum ShuffleKind { /** * Put all data in a single partition, with no sorting and no statistics gathering. */ - MIX(false, false), + MIX("mix", false, false), /** * Partition using hash codes, with no sorting. * * This kind of shuffle supports pipelining: producer and consumer stages can run at the same time. */ - HASH(true, false), + HASH("hash", true, false), /** * Partition using hash codes, with each partition internally sorted. 
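The ShuffleKind constants above now carry explicit camelCase wire names ("mix", "hash", and so on), and the @JsonCreator/@JsonValue pair added in the following hunks routes JSON (de)serialization through those names instead of the uppercase enum identifiers. A hedged round-trip sketch; the harness class is hypothetical, while the annotation behavior is standard Jackson:

    import com.fasterxml.jackson.databind.ObjectMapper;
    import org.apache.druid.msq.kernel.ShuffleKind;

    public class ShuffleKindJsonExample
    {
      public static void main(String[] args) throws Exception
      {
        final ObjectMapper mapper = new ObjectMapper();

        // @JsonValue routes serialization through toString(), producing "hash"
        // rather than the default enum name "HASH".
        final String json = mapper.writeValueAsString(ShuffleKind.HASH);

        // @JsonCreator fromString() restores the constant; unknown names throw IAE.
        final ShuffleKind kind = mapper.readValue(json, ShuffleKind.class);

        System.out.println(json + " -> " + kind); // "hash" -> hash
      }
    }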
@@ -42,7 +46,7 @@ public enum ShuffleKind * Due to the need to sort outputs, this shuffle mechanism cannot be pipelined. Producer stages must finish before * consumer stages can run. */ - HASH_LOCAL_SORT(true, true), + HASH_LOCAL_SORT("hashLocalSort", true, true), /** * Partition using a distributed global sort. @@ -58,17 +62,31 @@ public enum ShuffleKind * Due to the need to sort outputs, this shuffle mechanism cannot be pipelined. Producer stages must finish before * consumer stages can run. */ - GLOBAL_SORT(false, true); + GLOBAL_SORT("globalSort", false, true); + private final String name; private final boolean hash; private final boolean sort; - ShuffleKind(boolean hash, boolean sort) + ShuffleKind(String name, boolean hash, boolean sort) { + this.name = name; this.hash = hash; this.sort = sort; } + @JsonCreator + public static ShuffleKind fromString(final String s) + { + for (final ShuffleKind kind : values()) { + if (kind.toString().equals(s)) { + return kind; + } + } + + throw new IAE("No such shuffleKind[%s]", s); + } + /** * Whether this shuffle does hash-partitioning. */ @@ -84,4 +102,11 @@ public boolean isSort() { return sort; } + + @Override + @JsonValue + public String toString() + { + return name; + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java index 37f53fca199d..4b7971a7f783 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java @@ -59,17 +59,13 @@ public interface ShuffleSpec ClusterBy clusterBy(); /** - * Whether this stage aggregates by the {@link #clusterBy()} key. - */ - boolean doesAggregate(); - - /** - * Number of partitions, if known. + * Number of partitions, if known in advance. * * Partition count is always known if {@link #kind()} is {@link ShuffleKind#MIX}, {@link ShuffleKind#HASH}, or - * {@link ShuffleKind#HASH_LOCAL_SORT}. It is not known if {@link #kind()} is {@link ShuffleKind#GLOBAL_SORT}. + * {@link ShuffleKind#HASH_LOCAL_SORT}. For {@link ShuffleKind#GLOBAL_SORT}, it is known if we have a single + * output partition. 
* - * @throws IllegalStateException if kind is {@link ShuffleKind#GLOBAL_SORT} + * @throws IllegalStateException if kind is {@link ShuffleKind#GLOBAL_SORT} with more than one target partition */ int partitionCount(); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java index 4e212949d5ee..80b912faa8da 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java @@ -345,7 +345,7 @@ public ClusterByStatisticsCollector createResultKeyStatisticsCollector(final int signature, maxRetainedBytes, Limits.MAX_PARTITION_BUCKETS, - shuffleSpec.doesAggregate(), + ((GlobalSortShuffleSpec) shuffleSpec).doesAggregate(), shuffleCheckHasMultipleValues ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageId.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageId.java index 35c8dc43665c..5b98eed0da95 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageId.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageId.java @@ -21,8 +21,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonValue; -import com.google.common.base.Strings; import org.apache.druid.common.guava.GuavaUtils; +import org.apache.druid.common.utils.IdUtils; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; @@ -43,15 +43,11 @@ public class StageId implements Comparable public StageId(final String queryId, final int stageNumber) { - if (Strings.isNullOrEmpty(queryId)) { - throw new IAE("Null or empty queryId"); - } - if (stageNumber < 0) { throw new IAE("Invalid stageNumber [%s]", stageNumber); } - this.queryId = queryId; + this.queryId = IdUtils.validateId("queryId", queryId); this.stageNumber = stageNumber; } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java index b9a3024048b0..201a1783c05f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java @@ -23,6 +23,8 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.InputSlice; import javax.annotation.Nullable; @@ -46,6 +48,12 @@ public class WorkOrder private final List workerInputs; private final ExtraInfoHolder extraInfoHolder; + @Nullable + private final List workerIds; + + @Nullable + private final OutputChannelMode outputChannelMode; + @JsonCreator @SuppressWarnings("rawtypes") public WorkOrder( @@ -53,7 +61,9 @@ public WorkOrder( @JsonProperty("stage") final int stageNumber, @JsonProperty("worker") final int workerNumber, @JsonProperty("input") final List workerInputs, - @JsonProperty("extra") @Nullable final ExtraInfoHolder extraInfoHolder + @JsonProperty("extra") @Nullable final ExtraInfoHolder 
extraInfoHolder, + @JsonProperty("workers") @Nullable final List workerIds, + @JsonProperty("output") @Nullable final OutputChannelMode outputChannelMode ) { this.queryDefinition = Preconditions.checkNotNull(queryDefinition, "queryDefinition"); @@ -61,6 +71,8 @@ public WorkOrder( this.workerNumber = workerNumber; this.workerInputs = Preconditions.checkNotNull(workerInputs, "workerInputs"); this.extraInfoHolder = extraInfoHolder; + this.workerIds = workerIds; + this.outputChannelMode = outputChannelMode; } @JsonProperty("query") @@ -95,6 +107,31 @@ ExtraInfoHolder getExtraInfoHolder() return extraInfoHolder; } + /** + * Worker IDs for this query, if known in advance (at the time the work order is created). May be null, in which + * case workers use {@link ControllerClient#getTaskList()} to find worker IDs. + */ + @Nullable + @JsonProperty("workers") + @JsonInclude(JsonInclude.Include.NON_NULL) + public List getWorkerIds() + { + return workerIds; + } + + public boolean hasOutputChannelMode() + { + return outputChannelMode != null; + } + + @Nullable + @JsonProperty("output") + @JsonInclude(JsonInclude.Include.NON_NULL) + public OutputChannelMode getOutputChannelMode() + { + return outputChannelMode; + } + @Nullable public Object getExtraInfo() { @@ -106,6 +143,23 @@ public StageDefinition getStageDefinition() return queryDefinition.getStageDefinition(stageNumber); } + public WorkOrder withOutputChannelMode(final OutputChannelMode newOutputChannelMode) + { + if (newOutputChannelMode == outputChannelMode) { + return this; + } else { + return new WorkOrder( + queryDefinition, + stageNumber, + workerNumber, + workerInputs, + extraInfoHolder, + workerIds, + newOutputChannelMode + ); + } + } + @Override public boolean equals(Object o) { @@ -120,13 +174,23 @@ public boolean equals(Object o) && workerNumber == workOrder.workerNumber && Objects.equals(queryDefinition, workOrder.queryDefinition) && Objects.equals(workerInputs, workOrder.workerInputs) - && Objects.equals(extraInfoHolder, workOrder.extraInfoHolder); + && Objects.equals(extraInfoHolder, workOrder.extraInfoHolder) + && Objects.equals(workerIds, workOrder.workerIds) + && Objects.equals(outputChannelMode, workOrder.outputChannelMode); } @Override public int hashCode() { - return Objects.hash(queryDefinition, stageNumber, workerInputs, workerNumber, extraInfoHolder); + return Objects.hash( + queryDefinition, + stageNumber, + workerNumber, + workerInputs, + extraInfoHolder, + workerIds, + outputChannelMode + ); } @Override @@ -138,6 +202,8 @@ public String toString() ", workerNumber=" + workerNumber + ", workerInputs=" + workerInputs + ", extraInfoHolder=" + extraInfoHolder + + ", workerIds=" + workerIds + + ", outputChannelMode=" + outputChannelMode + '}'; } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java index 18f1f821d9a6..cdf4e2e20b0b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java @@ -33,12 +33,12 @@ import java.util.OptionalInt; /** - * Strategy for assigning input slices to tasks. Influences how {@link InputSpecSlicer} is used. + * Strategy for assigning input slices to workers. Influences how {@link InputSpecSlicer} is used. 
*/ public enum WorkerAssignmentStrategy { /** - * Use the highest possible number of tasks, while staying within {@link StageDefinition#getMaxWorkerCount()}. + * Use the highest possible number of workers, while staying within {@link StageDefinition#getMaxWorkerCount()}. * * Implemented using {@link InputSpecSlicer#sliceStatic}. */ @@ -57,7 +57,7 @@ public List assign( }, /** - * Use the lowest possible number of tasks, while keeping each task's workload under + * Use the lowest possible number of workers, while keeping each worker's workload under * {@link Limits#MAX_INPUT_FILES_PER_WORKER} files and {@code maxInputBytesPerWorker} bytes. * * Implemented using {@link InputSpecSlicer#sliceDynamic} whenever possible. diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java index c7805f04a9f3..05e0f722ccd4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java @@ -20,7 +20,6 @@ package org.apache.druid.msq.kernel.controller; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -33,6 +32,7 @@ import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.exec.QueryValidator; import org.apache.druid.msq.indexing.error.CanceledFault; import org.apache.druid.msq.indexing.error.MSQException; @@ -41,7 +41,6 @@ import org.apache.druid.msq.indexing.error.UnknownFault; import org.apache.druid.msq.indexing.error.WorkerFailedFault; import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault; -import org.apache.druid.msq.input.InputSpecSlicer; import org.apache.druid.msq.input.InputSpecSlicerFactory; import org.apache.druid.msq.input.stage.ReadablePartitions; import org.apache.druid.msq.kernel.ExtraInfoHolder; @@ -55,14 +54,17 @@ import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; import javax.annotation.Nullable; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Optional; +import java.util.Queue; import java.util.Set; +import java.util.SortedSet; +import java.util.function.Consumer; import java.util.stream.Collectors; /** @@ -77,39 +79,49 @@ public class ControllerQueryKernel { private static final Logger log = new Logger(ControllerQueryKernel.class); + private final QueryDefinition queryDef; + private final ControllerQueryKernelConfig config; /** * Stage ID -> tracker for that stage. An extension of the state of this kernel. */ - private final Map stageTracker = new HashMap<>(); + private final Map stageTrackers = new HashMap<>(); /** - * Stage ID -> stages that flow *into* that stage. Computed by {@link #computeStageInflowMap}. + * Stage ID -> stages that flow *into* that stage. Computed by {@link ControllerQueryKernelUtils#computeStageInflowMap}. 
*/ private final ImmutableMap> inflowMap; /** - * Stage ID -> stages that *depend on* that stage. Computed by {@link #computeStageOutflowMap}. + * Stage ID -> stages that *depend on* that stage. Computed by {@link ControllerQueryKernelUtils#computeStageOutflowMap}. */ private final ImmutableMap> outflowMap; /** * Maintains a running map of (stageId -> pending inflow stages) which need to be completed to provision the stage * corresponding to the stageId. After initializing, if the value of the entry becomes an empty set, it is removed - * from the map, and the removed entry is added to {@link #readyToRunStages}. + * from the map, and the removed entry is added to {@link #stageGroupQueue}. */ - private final Map> pendingInflowMap; + private final Map> pendingInflowMap; /** * Maintains a running count of (stageId -> outflow stages pending on its results). After initializing, if * the value of the entry becomes an empty set, it is removed from the map and the removed entry is added to * {@link #effectivelyFinishedStages}. */ - private final Map> pendingOutflowMap; + private final Map> pendingOutflowMap; /** - * Tracks those stages which can be initialized safely. + * Stage groups, in the order that we will run them. Each group is a set of stages that internally uses + * {@link OutputChannelMode#MEMORY} for communication. (The final stage may use a different + * {@link OutputChannelMode}. In particular, if a stage group has a single stage, it may use any + * {@link OutputChannelMode}.) + */ + private final Queue stageGroupQueue; + + /** + * Tracks those stages which are ready to begin executing. Populated by {@link #registerStagePhaseChange}. */ private final Set readyToRunStages = new HashSet<>(); @@ -123,7 +135,12 @@ public class ControllerQueryKernel * Map> * Stores the work order per worker per stage so that we can retrieve that in case of worker retry */ - private final Map> stageWorkOrders; + private final Map> stageWorkOrders = new HashMap<>(); + + /** + * Tracks the output channel mode for each stage. + */ + private final Map stageOutputChannelModes = new HashMap<>(); /** * {@link MSQFault#getErrorCode()} which are retried. @@ -133,27 +150,22 @@ public class ControllerQueryKernel UnknownFault.CODE, WorkerRpcFailedFault.CODE ); - private final int maxRetainedPartitionSketchBytes; - private final boolean faultToleranceEnabled; public ControllerQueryKernel( final QueryDefinition queryDef, - int maxRetainedPartitionSketchBytes, - boolean faultToleranceEnabled + final ControllerQueryKernelConfig config ) { this.queryDef = queryDef; - this.maxRetainedPartitionSketchBytes = maxRetainedPartitionSketchBytes; - this.faultToleranceEnabled = faultToleranceEnabled; - this.inflowMap = ImmutableMap.copyOf(computeStageInflowMap(queryDef)); - this.outflowMap = ImmutableMap.copyOf(computeStageOutflowMap(queryDef)); + this.config = config; + this.inflowMap = ImmutableMap.copyOf(ControllerQueryKernelUtils.computeStageInflowMap(queryDef)); + this.outflowMap = ImmutableMap.copyOf(ControllerQueryKernelUtils.computeStageOutflowMap(queryDef)); // pendingInflowMap and pendingOutflowMap are wholly separate from inflowMap, so we can edit the Sets. 
- this.pendingInflowMap = computeStageInflowMap(queryDef); - this.pendingOutflowMap = computeStageOutflowMap(queryDef); - - stageWorkOrders = new HashMap<>(); + this.pendingInflowMap = ControllerQueryKernelUtils.computeStageInflowMap(queryDef); + this.pendingOutflowMap = ControllerQueryKernelUtils.computeStageOutflowMap(queryDef); + this.stageGroupQueue = new ArrayDeque<>(ControllerQueryKernelUtils.computeStageGroups(queryDef, config)); initializeReadyToRunStages(); } @@ -166,31 +178,24 @@ public List createAndGetNewStageIds( final long maxInputBytesPerWorker ) { - final Int2IntMap stageWorkerCountMap = new Int2IntAVLTreeMap(); - final Int2ObjectMap stagePartitionsMap = new Int2ObjectAVLTreeMap<>(); - - for (final ControllerStageTracker stageKernel : stageTracker.values()) { - final int stageNumber = stageKernel.getStageDefinition().getStageNumber(); - stageWorkerCountMap.put(stageNumber, stageKernel.getWorkerInputs().workerCount()); - - if (stageKernel.hasResultPartitions()) { - stagePartitionsMap.put(stageNumber, stageKernel.getResultPartitions()); - } - } + createNewKernels( + slicerFactory, + assignmentStrategy, + maxInputBytesPerWorker + ); - createNewKernels(stageWorkerCountMap, slicerFactory.makeSlicer(stagePartitionsMap), assignmentStrategy, maxInputBytesPerWorker); - return stageTracker.values() - .stream() - .filter(controllerStageTracker -> controllerStageTracker.getPhase() == ControllerStagePhase.NEW) - .map(stageKernel -> stageKernel.getStageDefinition().getId()) - .collect(Collectors.toList()); + return stageTrackers.values() + .stream() + .filter(controllerStageTracker -> controllerStageTracker.getPhase() == ControllerStagePhase.NEW) + .map(stageTracker -> stageTracker.getStageDefinition().getId()) + .collect(Collectors.toList()); } /** * @return Stage kernels in this query kernel which can be safely cleaned up and marked as FINISHED. This returns the * kernel corresponding to a particular stage only once, to reduce the number of stages to iterate through. * It is expectant of the caller to eventually mark the stage as {@link ControllerStagePhase#FINISHED} after fetching - * the stage kernel + * the stage tracker */ public List getEffectivelyFinishedStageIds() { @@ -202,7 +207,23 @@ public List getEffectivelyFinishedStageIds() */ public List getActiveStages() { - return ImmutableList.copyOf(stageTracker.keySet()); + return ImmutableList.copyOf(stageTrackers.keySet()); + } + + /** + * Returns the number of stages that are active and in non-terminal phases. + */ + public int getNonTerminalActiveStageCount() + { + int n = 0; + + for (final ControllerStageTracker tracker : stageTrackers.values()) { + if (!tracker.getPhase().isTerminal() && tracker.getPhase() != ControllerStagePhase.RESULTS_READY) { + n++; + } + } + + return n; } /** @@ -219,10 +240,8 @@ public StageId getStageId(final int stageNumber) */ public boolean isDone() { - return Optional.ofNullable(stageTracker.get(queryDef.getFinalStageDefinition().getId())) - .filter(tracker -> ControllerStagePhase.isSuccessfulTerminalPhase(tracker.getPhase())) - .isPresent() - || stageTracker.values().stream().anyMatch(tracker -> tracker.getPhase() == ControllerStagePhase.FAILED); + return isSuccess() + || stageTrackers.values().stream().anyMatch(tracker -> tracker.getPhase() == ControllerStagePhase.FAILED); } /** @@ -237,7 +256,7 @@ public void markSuccessfulTerminalStagesAsFinished() // terminal phases" to FINISHED at the end, hence the if clause. 
Inside the conditional, depending on the // terminal phase it resides in, we synthetically mark it to completion (and therefore we need to check which // stage it is precisely in) - if (ControllerStagePhase.isSuccessfulTerminalPhase(phase)) { + if (phase.isSuccess()) { if (phase == ControllerStagePhase.RESULTS_READY) { finishStage(stageId, false); } @@ -246,14 +265,14 @@ public void markSuccessfulTerminalStagesAsFinished() } /** - * Returns true if all the stages comprising the query definition have been successful in producing their results + * Returns true if all the stages comprising the query definition have been successful in producing their results. */ public boolean isSuccess() { - return stageTracker.size() == queryDef.getStageDefinitions().size() - && stageTracker.values() - .stream() - .allMatch(tracker -> ControllerStagePhase.isSuccessfulTerminalPhase(tracker.getPhase())); + return stageTrackers.size() == queryDef.getStageDefinitions().size() + && stageTrackers.values() + .stream() + .allMatch(tracker -> tracker.getPhase() == ControllerStagePhase.FINISHED); } /** @@ -265,9 +284,10 @@ public Int2ObjectMap createWorkOrders( ) { final Int2ObjectMap workerToWorkOrder = new Int2ObjectAVLTreeMap<>(); - final ControllerStageTracker stageKernel = getStageKernelOrThrow(getStageId(stageNumber)); - + final ControllerStageTracker stageKernel = getStageTrackerOrThrow(getStageId(stageNumber)); final WorkerInputs workerInputs = stageKernel.getWorkerInputs(); + final OutputChannelMode outputChannelMode = stageOutputChannelModes.get(stageKernel.getStageDefinition().getId()); + for (int workerNumber : workerInputs.workers()) { final Object extraInfo = extraInfos != null ? extraInfos.get(workerNumber) : null; @@ -280,7 +300,9 @@ public Int2ObjectMap createWorkOrders( stageNumber, workerNumber, workerInputs.inputsForWorker(workerNumber), - extraInfoHolder + extraInfoHolder, + config.getWorkerIds(), + outputChannelMode ); QueryValidator.validateWorkOrder(workOrder); @@ -291,27 +313,80 @@ public Int2ObjectMap createWorkOrders( } private void createNewKernels( - final Int2IntMap stageWorkerCountMap, - final InputSpecSlicer slicer, + final InputSpecSlicerFactory slicerFactory, final WorkerAssignmentStrategy assignmentStrategy, final long maxInputBytesPerWorker ) { - for (final StageId nextStage : readyToRunStages) { - // Create a tracker. - final StageDefinition stageDef = queryDef.getStageDefinition(nextStage); - final ControllerStageTracker stageKernel = ControllerStageTracker.create( - stageDef, - stageWorkerCountMap, - slicer, - assignmentStrategy, - maxRetainedPartitionSketchBytes, - maxInputBytesPerWorker - ); - stageTracker.put(nextStage, stageKernel); + StageGroup stageGroup; + + while ((stageGroup = stageGroupQueue.peek()) != null) { + if (readyToRunStages.contains(stageGroup.first()) + && getNonTerminalActiveStageCount() + stageGroup.size() <= config.getMaxConcurrentStages()) { + // There is room to launch this stage group. + stageGroupQueue.poll(); + + for (final StageId stageId : stageGroup.stageIds()) { + // Create a tracker for this stage. + stageTrackers.put( + stageId, + createStageTracker( + stageId, + slicerFactory, + assignmentStrategy, + maxInputBytesPerWorker + ) + ); + + // Store output channel mode. 
+ stageOutputChannelModes.put( + stageId, + stageGroup.stageOutputChannelMode(stageId) + ); + } + + stageGroup.stageIds().forEach(readyToRunStages::remove); + } else { + break; + } + } + } + + private ControllerStageTracker createStageTracker( + final StageId stageId, + final InputSpecSlicerFactory slicerFactory, + final WorkerAssignmentStrategy assignmentStrategy, + final long maxInputBytesPerWorker + ) + { + final Int2IntMap stageWorkerCountMap = new Int2IntAVLTreeMap(); + final Int2ObjectMap stagePartitionsMap = new Int2ObjectAVLTreeMap<>(); + final Int2ObjectMap stageOutputChannelModeMap = new Int2ObjectAVLTreeMap<>(); + + for (final ControllerStageTracker stageTracker : stageTrackers.values()) { + final int stageNumber = stageTracker.getStageDefinition().getStageNumber(); + stageWorkerCountMap.put(stageNumber, stageTracker.getWorkerInputs().workerCount()); + + if (stageTracker.hasResultPartitions()) { + stagePartitionsMap.put(stageNumber, stageTracker.getResultPartitions()); + } + + final OutputChannelMode outputChannelMode = + stageOutputChannelModes.get(stageTracker.getStageDefinition().getId()); + + if (outputChannelMode != null) { + stageOutputChannelModeMap.put(stageNumber, outputChannelMode); + } } - readyToRunStages.clear(); + return ControllerStageTracker.create( + getStageDefinition(stageId), + stageWorkerCountMap, + slicerFactory.makeSlicer(stagePartitionsMap, stageOutputChannelModeMap), + assignmentStrategy, + config.getMaxRetainedPartitionSketchBytes(), + maxInputBytesPerWorker + ); } /** @@ -320,33 +395,81 @@ private void createNewKernels( */ private void initializeReadyToRunStages() { - final Iterator>> pendingInflowIterator = pendingInflowMap.entrySet().iterator(); + final List readyStages = new ArrayList<>(); + final Iterator>> pendingInflowIterator = + pendingInflowMap.entrySet().iterator(); while (pendingInflowIterator.hasNext()) { - Map.Entry> stageToInflowStages = pendingInflowIterator.next(); - if (stageToInflowStages.getValue().size() == 0) { - readyToRunStages.add(stageToInflowStages.getKey()); + final Map.Entry> stageToInflowStages = pendingInflowIterator.next(); + if (stageToInflowStages.getValue().isEmpty()) { + readyStages.add(stageToInflowStages.getKey()); pendingInflowIterator.remove(); } } - } - // Following section contains the methods which delegate to appropriate stage kernel + readyToRunStages.addAll(readyStages); + } /** - * Delegates call to {@link ControllerStageTracker#getStageDefinition()} + * Returns the definition of a given stage. + * + * @throws NullPointerException if there is no stage with the given ID */ public StageDefinition getStageDefinition(final StageId stageId) { - return getStageKernelOrThrow(stageId).getStageDefinition(); + return queryDef.getStageDefinition(stageId); + } + + /** + * Returns the {@link OutputChannelMode} for a given stage. + * + * @throws IllegalStateException if there is no stage with the given ID + */ + public OutputChannelMode getStageOutputChannelMode(final StageId stageId) + { + final OutputChannelMode outputChannelMode = stageOutputChannelModes.get(stageId); + if (outputChannelMode == null) { + throw new ISE("No such stage[%s]", stageId); + } + + return outputChannelMode; } + /** + * Whether query results are readable. 
+ */ + public boolean canReadQueryResults() + { + final StageId finalStageId = queryDef.getFinalStageDefinition().getId(); + final ControllerStageTracker stageTracker = stageTrackers.get(finalStageId); + if (stageTracker == null) { + return false; + } else { + final OutputChannelMode outputChannelMode = stageOutputChannelModes.get(finalStageId); + if (outputChannelMode == OutputChannelMode.MEMORY) { + return stageTracker.getPhase().isRunning(); + } else { + return stageTracker.getPhase() == ControllerStagePhase.RESULTS_READY; + } + } + } + + // Following section contains the methods which delegate to appropriate stage kernel + /** * Delegates call to {@link ControllerStageTracker#getPhase()} */ public ControllerStagePhase getStagePhase(final StageId stageId) { - return getStageKernelOrThrow(stageId).getPhase(); + return getStageTrackerOrThrow(stageId).getPhase(); + } + + /** + * Returns whether a particular stage is finished. Stages can finish early if their outputs are no longer needed. + */ + public boolean isStageFinished(final StageId stageId) + { + return getStagePhase(stageId) == ControllerStagePhase.FINISHED; } /** @@ -354,7 +477,7 @@ public ControllerStagePhase getStagePhase(final StageId stageId) */ public boolean doesStageHaveResultPartitions(final StageId stageId) { - return getStageKernelOrThrow(stageId).hasResultPartitions(); + return getStageTrackerOrThrow(stageId).hasResultPartitions(); } /** @@ -362,7 +485,7 @@ public boolean doesStageHaveResultPartitions(final StageId stageId) */ public ReadablePartitions getResultPartitionsForStage(final StageId stageId) { - return getStageKernelOrThrow(stageId).getResultPartitions(); + return getStageTrackerOrThrow(stageId).getResultPartitions(); } /** @@ -370,7 +493,7 @@ public ReadablePartitions getResultPartitionsForStage(final StageId stageId) */ public IntSet getWorkersToSendPartitionBoundaries(final StageId stageId) { - return getStageKernelOrThrow(stageId).getWorkersToSendPartitionBoundaries(); + return getStageTrackerOrThrow(stageId).getWorkersToSendPartitionBoundaries(); } /** @@ -378,7 +501,7 @@ public IntSet getWorkersToSendPartitionBoundaries(final StageId stageId) */ public void workOrdersSentForWorker(final StageId stageId, int worker) { - getStageKernelOrThrow(stageId).workOrderSentForWorker(worker); + doWithStageTracker(stageId, stageTracker -> stageTracker.workOrderSentForWorker(worker)); } /** @@ -386,7 +509,7 @@ public void workOrdersSentForWorker(final StageId stageId, int worker) */ public void partitionBoundariesSentForWorker(final StageId stageId, int worker) { - getStageKernelOrThrow(stageId).partitionBoundariesSentForWorker(worker); + doWithStageTracker(stageId, stageTracker -> stageTracker.partitionBoundariesSentForWorker(worker)); } /** @@ -394,7 +517,7 @@ public void partitionBoundariesSentForWorker(final StageId stageId, int worker) */ public ClusterByPartitions getResultPartitionBoundariesForStage(final StageId stageId) { - return getStageKernelOrThrow(stageId).getResultPartitionBoundaries(); + return getStageTrackerOrThrow(stageId).getResultPartitionBoundaries(); } /** @@ -402,7 +525,7 @@ public ClusterByPartitions getResultPartitionBoundariesForStage(final StageId st */ public CompleteKeyStatisticsInformation getCompleteKeyStatisticsInformation(final StageId stageId) { - return getStageKernelOrThrow(stageId).getCompleteKeyStatisticsInformation(); + return getStageTrackerOrThrow(stageId).getCompleteKeyStatisticsInformation(); } /** @@ -410,7 +533,7 @@ public CompleteKeyStatisticsInformation 
getCompleteKeyStatisticsInformation(fina */ public boolean hasStageCollectorEncounteredAnyMultiValueField(final StageId stag { - return getStageKernelOrThrow(stageId).collectorEncounteredAnyMultiValueField(); + return getStageTrackerOrThrow(stageId).collectorEncounteredAnyMultiValueField(); } /** @@ -418,7 +541,7 @@ public boolean hasStageCollectorEncounteredAnyMultiValueField(final StageId stag */ public Object getResultObjectForStage(final StageId stageId) { - return getStageKernelOrThrow(stageId).getResultObject(); + return getStageTrackerOrThrow(stageId).getResultObject(); } /** @@ -427,15 +550,17 @@ public Object getResultObjectForStage(final StageId stageId) */ public void startStage(final StageId stageId) { - final ControllerStageTracker stageKernel = getStageKernelOrThrow(stageId); - if (stageKernel.getPhase() != ControllerStagePhase.NEW) { - throw new ISE("Cannot start the stage: [%s]", stageId); - } if (stageWorkOrders.get(stageId) == null) { - throw new ISE("Work orders not present for stage %s", stageId); + throw new ISE("Work order not present for stage[%s]", stageId); } - stageKernel.start(); - transitionStageKernel(stageId, ControllerStagePhase.READING_INPUT); + + doWithStageTracker(stageId, stageTracker -> { + if (stageTracker.getPhase() != ControllerStagePhase.NEW) { + throw new ISE("Cannot start the stage: [%s]", stageId); + } + + stageTracker.start(); + }); } /** @@ -450,9 +575,10 @@ public void finishStage(final StageId stageId, final boolean strict) if (strict && !effectivelyFinishedStages.contains(stageId)) { throw new IAE("Cannot mark the stage: [%s] finished", stageId); } - getStageKernelOrThrow(stageId).finish(); - effectivelyFinishedStages.remove(stageId); - transitionStageKernel(stageId, ControllerStagePhase.FINISHED); + doWithStageTracker(stageId, stageTracker -> { + stageTracker.finish(); + effectivelyFinishedStages.remove(stageId); + }); stageWorkOrders.remove(stageId); } @@ -461,7 +587,7 @@ public void finishStage(final StageId stageId, final boolean strict) */ public WorkerInputs getWorkerInputsForStage(final StageId stageId) { - return getStageKernelOrThrow(stageId).getWorkerInputs(); + return getStageTrackerOrThrow(stageId).getWorkerInputs(); } /** @@ -474,20 +600,17 @@ public void addPartialKeyStatisticsForStageAndWorker( final PartialKeyStatisticsInformation partialKeyStatisticsInformation ) { - ControllerStageTracker stageKernel = getStageKernelOrThrow(stageId); - ControllerStagePhase newPhase = stageKernel.addPartialKeyInformationForWorker( - workerNumber, - partialKeyStatisticsInformation - ); + doWithStageTracker(stageId, stageTracker -> + stageTracker.addPartialKeyInformationForWorker(workerNumber, partialKeyStatisticsInformation)); + } - // If the kernel phase has transitioned, we need to account for that. - switch (newPhase) { - case MERGING_STATISTICS: - case POST_READING: - case FAILED: - transitionStageKernel(stageId, newPhase); - break; - } + /** + * Delegates call to {@link ControllerStageTracker#setDoneReadingInputForWorker(int)}. 
+ * If calling this causes transition for the stage kernel, then this gets registered in this query kernel + */ + public void setDoneReadingInputForStageAndWorker(final StageId stageId, final int workerNumber) + { + doWithStageTracker(stageId, stageTracker -> stageTracker.setDoneReadingInputForWorker(workerNumber)); } /** @@ -500,9 +623,7 @@ public void setResultsCompleteForStageAndWorker( final Object resultObject ) { - if (getStageKernelOrThrow(stageId).setResultsCompleteForWorker(workerNumber, resultObject)) { - transitionStageKernel(stageId, ControllerStagePhase.RESULTS_READY); - } + doWithStageTracker(stageId, stageTracker -> stageTracker.setResultsCompleteForWorker(workerNumber, resultObject)); } /** @@ -510,13 +631,7 @@ public void setResultsCompleteForStageAndWorker( */ public MSQFault getFailureReasonForStage(final StageId stageId) { - return getStageKernelOrThrow(stageId).getFailureReason(); - } - - public void failStageForReason(final StageId stageId, MSQFault fault) - { - getStageKernelOrThrow(stageId).failForReason(fault); - transitionStageKernel(stageId, ControllerStagePhase.FAILED); + return getStageTrackerOrThrow(stageId).getFailureReason(); } /** @@ -524,20 +639,19 @@ public void failStageForReason(final StageId stageId, MSQFault fault) */ public void failStage(final StageId stageId) { - getStageKernelOrThrow(stageId).fail(); - transitionStageKernel(stageId, ControllerStagePhase.FAILED); + doWithStageTracker(stageId, ControllerStageTracker::fail); } /** * Fetches and returns the stage kernel corresponding to the provided stage id, else throws {@link IAE} */ - private ControllerStageTracker getStageKernelOrThrow(StageId stageId) + private ControllerStageTracker getStageTrackerOrThrow(StageId stageId) { - ControllerStageTracker stageKernel = stageTracker.get(stageId); - if (stageKernel == null) { + ControllerStageTracker stageTracker = stageTrackers.get(stageId); + if (stageTracker == null) { throw new IAE("Cannot find kernel corresponding to stage [%s] in query [%s]", stageId, queryDef.getQueryId()); } - return stageKernel; + return stageTracker; } private WorkOrder getWorkOrder(int workerNumber, StageId stageId) @@ -556,99 +670,99 @@ private WorkOrder getWorkOrder(int workerNumber, StageId stageId) } /** - * Whenever a stage kernel changes its phase, the change must be "registered" by calling this method with the stageId - * and the new phase + * Whether a given stage is ready to stream results to consumer stages upon transition to "newPhase". */ - public void transitionStageKernel(StageId stageId, ControllerStagePhase newPhase) + private boolean readyToReadResults(final StageId stageId, final ControllerStagePhase newPhase) { - Preconditions.checkArgument( - stageTracker.containsKey(stageId), - "Attempting to modify an unknown stageKernel" - ); + if (stageOutputChannelModes.get(stageId) == OutputChannelMode.MEMORY) { + if (getStageDefinition(stageId).doesSortDuringShuffle()) { + // Stages that sort during shuffle go through a READING_INPUT phase followed by a POST_READING phase + // (once all input is read). These stages start producing output once POST_READING starts. + return newPhase == ControllerStagePhase.POST_READING; + } else { + // Can read results immediately. 
+ return newPhase == ControllerStagePhase.NEW; + } + } else { + return newPhase == ControllerStagePhase.RESULTS_READY; + } + } + + private void doWithStageTracker(final StageId stageId, final Consumer fn) + { + final ControllerStageTracker stageTracker = getStageTrackerOrThrow(stageId); + final ControllerStagePhase phase = stageTracker.getPhase(); + fn.accept(stageTracker); + + if (phase != stageTracker.getPhase()) { + registerStagePhaseChange(stageId, stageTracker.getPhase()); + } + } - if (newPhase == ControllerStagePhase.RESULTS_READY) { - // Once the stage has produced its results, we remove it from all the stages depending on this stage (for its - // output). + /** + * Whenever a stage kernel changes its phase, the change must be "registered" by calling this method with the stageId + * and the new phase. + */ + private void registerStagePhaseChange(final StageId stageId, final ControllerStagePhase newPhase) + { + if (readyToReadResults(stageId, newPhase)) { + // Once results from a stage are readable, remove this stage from pendingInflowMap and potentially mark + // dependent stages as ready to run. for (StageId dependentStageId : outflowMap.get(stageId)) { if (!pendingInflowMap.containsKey(dependentStageId)) { continue; } pendingInflowMap.get(dependentStageId).remove(stageId); // Check the dependent stage. If it has no dependencies left, it can be marked as to be initialized - if (pendingInflowMap.get(dependentStageId).size() == 0) { + if (pendingInflowMap.get(dependentStageId).isEmpty()) { readyToRunStages.add(dependentStageId); pendingInflowMap.remove(dependentStageId); } } } - if (ControllerStagePhase.isPostReadingPhase(newPhase)) { - - // when fault tolerance is enabled, we cannot delete the input data eagerly as we need the input stage for retry until - // results for the current stage are ready. - if (faultToleranceEnabled && newPhase == ControllerStagePhase.POST_READING) { - return; - } - // Once the stage has consumed all the data/input from its dependent stages, we remove it from all the stages - // whose input it was dependent on + if (newPhase.isSuccess() || (!config.isFaultTolerant() && newPhase.isDoneReadingInput())) { + // Once a stage no longer needs its input, we consider marking input stages as finished. for (StageId inputStage : inflowMap.get(stageId)) { if (!pendingOutflowMap.containsKey(inputStage)) { continue; } pendingOutflowMap.get(inputStage).remove(stageId); - // If no more stage is dependent on the "inputStage's" results, it can be safely transitioned to FINISHED - if (pendingOutflowMap.get(inputStage).size() == 0) { - effectivelyFinishedStages.add(inputStage); + // If no more stage is dependent on the inputStage's results, it can be safely transitioned to FINISHED + if (pendingOutflowMap.get(inputStage).isEmpty()) { pendingOutflowMap.remove(inputStage); + + // Mark input stage as effectively finished, if it's ready to finish. + // This leads to a later transition to FINISHED. + if (ControllerStagePhase.FINISHED.canTransitionFrom(stageTrackers.get(inputStage).getPhase())) { + effectivelyFinishedStages.add(inputStage); + } } } } - } - @VisibleForTesting - ControllerStageTracker getControllerStageKernel(int stageNumber) - { - return stageTracker.get(new StageId(queryDef.getQueryId(), stageNumber)); - } - - /** - * Returns a mapping of stage -> stages that flow *into* that stage. 
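At its core, the bookkeeping in registerStagePhaseChange is a dependency-drain pattern: once a stage's results become readable, the stage is removed from each dependent's pending-inflow set, and any dependent whose set empties becomes ready to run. A minimal self-contained sketch of that pattern, not part of this patch, with plain ints standing in for StageId:

    import java.util.ArrayDeque;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Map;
    import java.util.Queue;
    import java.util.Set;

    // Sketch of the pendingInflowMap bookkeeping: a stage becomes ready to run
    // once every upstream stage it reads from has readable results.
    class StageDependencyTracker
    {
      private final Map<Integer, Set<Integer>> pendingInflow = new HashMap<>();
      private final Queue<Integer> readyToRun = new ArrayDeque<>();

      void addStage(final int stageId, final Set<Integer> inputStages)
      {
        if (inputStages.isEmpty()) {
          readyToRun.add(stageId); // No inputs: runnable immediately.
        } else {
          pendingInflow.put(stageId, new HashSet<>(inputStages));
        }
      }

      // Called when a stage's results become readable (RESULTS_READY, or earlier
      // when its output is an in-memory channel); may unlock dependents.
      void onResultsReadable(final int stageId, final Set<Integer> dependents)
      {
        for (final int dependent : dependents) {
          final Set<Integer> waitingOn = pendingInflow.get(dependent);
          if (waitingOn != null && waitingOn.remove(stageId) && waitingOn.isEmpty()) {
            pendingInflow.remove(dependent);
            readyToRun.add(dependent); // All inputs available: ready to launch.
          }
        }
      }

      Queue<Integer> readyToRun()
      {
        return readyToRun;
      }
    }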
- */ - private static Map> computeStageInflowMap(final QueryDefinition queryDefinition) - { - final Map> retVal = new HashMap<>(); + // Mark stage as effectively finished, if it has no dependencies waiting for it. + // This leads to a later transition to FINISHED. + final boolean hasDependentStages = + pendingOutflowMap.containsKey(stageId) && !pendingOutflowMap.get(stageId).isEmpty(); - for (final StageDefinition stageDef : queryDefinition.getStageDefinitions()) { - final StageId stageId = stageDef.getId(); - retVal.computeIfAbsent(stageId, ignored -> new HashSet<>()); + if (!hasDependentStages) { + final boolean isFinalStage = queryDef.getFinalStageDefinition().getId().equals(stageId); - for (final int inputStageNumber : queryDefinition.getStageDefinition(stageId).getInputStageNumbers()) { - final StageId inputStageId = new StageId(queryDefinition.getQueryId(), inputStageNumber); - retVal.computeIfAbsent(stageId, ignored -> new HashSet<>()).add(inputStageId); + if (isFinalStage && newPhase == ControllerStagePhase.RESULTS_READY) { + // Final stage must run to completion (RESULTS_READY). + effectivelyFinishedStages.add(stageId); + } else if (!isFinalStage && ControllerStagePhase.FINISHED.canTransitionFrom(newPhase)) { + // Other stages can exit early (e.g. if there is a LIMIT). + effectivelyFinishedStages.add(stageId); } } - - return retVal; } - /** - * Returns a mapping of stage -> stages that depend on that stage. - */ - private static Map> computeStageOutflowMap(final QueryDefinition queryDefinition) + @VisibleForTesting + ControllerStageTracker getControllerStageTracker(int stageNumber) { - final Map> retVal = new HashMap<>(); - - for (final StageDefinition stageDef : queryDefinition.getStageDefinitions()) { - final StageId stageId = stageDef.getId(); - retVal.computeIfAbsent(stageId, ignored -> new HashSet<>()); - - for (final int inputStageNumber : queryDefinition.getStageDefinition(stageId).getInputStageNumbers()) { - final StageId inputStageId = new StageId(queryDefinition.getQueryId(), inputStageNumber); - retVal.computeIfAbsent(inputStageId, ignored -> new HashSet<>()).add(stageId); - } - } - - return retVal; + return stageTrackers.get(new StageId(queryDef.getQueryId(), stageNumber)); } /** @@ -660,6 +774,7 @@ private static Map> computeStageOutflowMap(final QueryDefi * * @param workerNumber * @param msqFault + * * @return List of {@link WorkOrder} that needs to be retried. */ public List getWorkInCaseWorkerEligibleForRetryElseThrow(int workerNumber, MSQFault msqFault) @@ -691,23 +806,23 @@ public static boolean isRetriableFault(MSQFault msqFault) * If yes adds the workOrder for that stage to the return list and transitions the stage kernel to {@link ControllerStagePhase#RETRYING} * * @param worker + * * @return List of {@link WorkOrder} that needs to be retried. */ private List getWorkInCaseWorkerEligibleForRetry(int worker) { List trackedSet = new ArrayList<>(getActiveStages()); - trackedSet.removeAll(getEffectivelyFinishedStageIds()); + trackedSet.removeAll(effectivelyFinishedStages); List workOrders = new ArrayList<>(); for (StageId stageId : trackedSet) { - ControllerStageTracker controllerStageTracker = getStageKernelOrThrow(stageId); - if (ControllerStagePhase.RETRYING.canTransitionFrom(controllerStageTracker.getPhase()) - && controllerStageTracker.retryIfNeeded(worker)) { - workOrders.add(getWorkOrder(worker, stageId)); - // should be a no-op. 
- transitionStageKernel(stageId, ControllerStagePhase.RETRYING); - } + doWithStageTracker(stageId, stageTracker -> { + if (ControllerStagePhase.RETRYING.canTransitionFrom(stageTracker.getPhase()) + && stageTracker.retryIfNeeded(worker)) { + workOrders.add(getWorkOrder(worker, stageId)); + } + }); } return workOrders; } @@ -723,7 +838,7 @@ public Map> getStagesAndWorkersToFetchClusterStats() Map> stageToWorkers = new HashMap<>(); for (StageId stageId : trackedSet) { - ControllerStageTracker controllerStageTracker = getStageKernelOrThrow(stageId); + ControllerStageTracker controllerStageTracker = getStageTrackerOrThrow(stageId); if (controllerStageTracker.getStageDefinition().mustGatherResultKeyStatistics()) { stageToWorkers.put(stageId, controllerStageTracker.getWorkersToFetchClusterStatisticsFrom()); } @@ -737,11 +852,11 @@ public Map> getStagesAndWorkersToFetchClusterStats() */ public void startFetchingStatsFromWorker(StageId stageId, Set workers) { - ControllerStageTracker controllerStageTracker = getStageKernelOrThrow(stageId); - - for (int worker : workers) { - controllerStageTracker.startFetchingStatsFromWorker(worker); - } + doWithStageTracker(stageId, stageTracker -> { + for (int worker : workers) { + stageTracker.startFetchingStatsFromWorker(worker); + } + }); } /** @@ -753,10 +868,8 @@ public void mergeClusterByStatisticsCollectorForAllTimeChunks( ClusterByStatisticsSnapshot clusterByStatsSnapshot ) { - getStageKernelOrThrow(stageId).mergeClusterByStatisticsCollectorForAllTimeChunks( - workerNumber, - clusterByStatsSnapshot - ); + doWithStageTracker(stageId, stageTracker -> + stageTracker.mergeClusterByStatisticsCollectorForAllTimeChunks(workerNumber, clusterByStatsSnapshot)); } /** @@ -770,11 +883,8 @@ public void mergeClusterByStatisticsCollectorForTimeChunk( ClusterByStatisticsSnapshot clusterByStatsSnapshot ) { - getStageKernelOrThrow(stageId).mergeClusterByStatisticsCollectorForTimeChunk( - workerNumber, - timeChunk, - clusterByStatsSnapshot - ); + doWithStageTracker(stageId, stageTracker -> + stageTracker.mergeClusterByStatisticsCollectorForTimeChunk(workerNumber, timeChunk, clusterByStatsSnapshot)); } /** @@ -782,7 +892,7 @@ public void mergeClusterByStatisticsCollectorForTimeChunk( */ public boolean allPartialKeyInformationPresent(StageId stageId) { - return getStageKernelOrThrow(stageId).allPartialKeyInformationFetched(); + return getStageTrackerOrThrow(stageId).allPartialKeyInformationFetched(); } /** diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelConfig.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelConfig.java new file mode 100644 index 000000000000..5c754aedd4f4 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelConfig.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.kernel.controller; + +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.msq.indexing.destination.MSQDestination; + +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +/** + * Configuration for {@link ControllerQueryKernel}. + */ +public class ControllerQueryKernelConfig +{ + private final int maxRetainedPartitionSketchBytes; + private final int maxConcurrentStages; + private final boolean pipeline; + private final boolean durableStorage; + private final boolean faultTolerance; + private final MSQDestination destination; + + @Nullable + private final String controllerId; + + @Nullable + private final List workerIds; + + private ControllerQueryKernelConfig( + int maxRetainedPartitionSketchBytes, + int maxConcurrentStages, + boolean pipeline, + boolean durableStorage, + boolean faultTolerance, + MSQDestination destination, + @Nullable String controllerId, + @Nullable List workerIds + ) + { + if (maxRetainedPartitionSketchBytes <= 0) { + throw new IAE("maxRetainedPartitionSketchBytes must be positive"); + } + + if (pipeline && maxConcurrentStages < 2) { + throw new IAE("maxConcurrentStagesPerWorker must be >= 2 when pipelining"); + } + + if (maxConcurrentStages <= 0) { + throw new IAE("maxConcurrentStagesPerWorker must be positive"); + } + + if (pipeline && faultTolerance) { + throw new IAE("Cannot pipeline with fault tolerance"); + } + + if (pipeline && durableStorage) { + throw new IAE("Cannot pipeline with durable storage"); + } + + if (faultTolerance && !durableStorage) { + throw new IAE("Cannot have fault tolerance without durable storage"); + } + + this.maxRetainedPartitionSketchBytes = maxRetainedPartitionSketchBytes; + this.maxConcurrentStages = maxConcurrentStages; + this.pipeline = pipeline; + this.durableStorage = durableStorage; + this.faultTolerance = faultTolerance; + this.destination = destination; + this.controllerId = controllerId; + this.workerIds = workerIds; + } + + public static Builder builder() + { + return new Builder(); + } + + public int getMaxRetainedPartitionSketchBytes() + { + return maxRetainedPartitionSketchBytes; + } + + public int getMaxConcurrentStages() + { + return maxConcurrentStages; + } + + public boolean isPipeline() + { + return pipeline; + } + + public boolean isDurableStorage() + { + return durableStorage; + } + + public boolean isFaultTolerant() + { + return faultTolerance; + } + + public MSQDestination getDestination() + { + return destination; + } + + @Nullable + public List getWorkerIds() + { + return workerIds; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ControllerQueryKernelConfig that = (ControllerQueryKernelConfig) o; + return maxRetainedPartitionSketchBytes == that.maxRetainedPartitionSketchBytes + && maxConcurrentStages == that.maxConcurrentStages + && pipeline == that.pipeline + && durableStorage == that.durableStorage + && faultTolerance == that.faultTolerance + && Objects.equals(controllerId, 
that.controllerId) + && Objects.equals(workerIds, that.workerIds); + } + + @Override + public int hashCode() + { + return Objects.hash( + maxRetainedPartitionSketchBytes, + maxConcurrentStages, + pipeline, + durableStorage, + faultTolerance, + controllerId, + workerIds + ); + } + + @Override + public String toString() + { + return "ControllerQueryKernelConfig{" + + "maxRetainedPartitionSketchBytes=" + maxRetainedPartitionSketchBytes + + ", maxConcurrentStages=" + maxConcurrentStages + + ", pipeline=" + pipeline + + ", durableStorage=" + durableStorage + + ", faultTolerant=" + faultTolerance + + ", controllerId='" + controllerId + '\'' + + ", workerIds=" + workerIds + + '}'; + } + + public static class Builder + { + private int maxRetainedPartitionSketchBytes = -1; + private int maxConcurrentStages = 1; + private boolean pipeline; + private boolean durableStorage; + private boolean faultTolerant; + private MSQDestination destination; + private String controllerId; + private List workerIds; + + /** + * Use {@link #builder()}. + */ + private Builder() + { + } + + public Builder maxRetainedPartitionSketchBytes(final int maxRetainedPartitionSketchBytes) + { + this.maxRetainedPartitionSketchBytes = maxRetainedPartitionSketchBytes; + return this; + } + + public Builder maxConcurrentStages(final int maxConcurrentStages) + { + this.maxConcurrentStages = maxConcurrentStages; + return this; + } + + public Builder pipeline(final boolean pipeline) + { + this.pipeline = pipeline; + return this; + } + + public Builder durableStorage(final boolean durableStorage) + { + this.durableStorage = durableStorage; + return this; + } + + public Builder faultTolerance(final boolean faultTolerant) + { + this.faultTolerant = faultTolerant; + return this; + } + + public Builder destination(final MSQDestination destination) + { + this.destination = destination; + return this; + } + + public Builder controllerId(final String controllerId) + { + this.controllerId = controllerId; + return this; + } + + public Builder workerIds(final List workerIds) + { + this.workerIds = workerIds; + return this; + } + + public ControllerQueryKernelConfig build() + { + return new ControllerQueryKernelConfig( + maxRetainedPartitionSketchBytes, + maxConcurrentStages, + pipeline, + durableStorage, + faultTolerant, + destination, + controllerId, + workerIds + ); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtils.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtils.java new file mode 100644 index 000000000000..d971f33a9f2d --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtils.java @@ -0,0 +1,406 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
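A hypothetical usage sketch of the builder defined above; the byte and stage counts, and the destination variable, are invented for illustration. The private constructor enforces the cross-flag rules: pipelining requires maxConcurrentStages >= 2 and excludes both fault tolerance and durable storage, while fault tolerance requires durable storage.

    // Assumes an MSQDestination instance named "destination" is in scope.
    final ControllerQueryKernelConfig pipelinedConfig =
        ControllerQueryKernelConfig
            .builder()
            .maxRetainedPartitionSketchBytes(100_000_000)
            .maxConcurrentStages(2)
            .pipeline(true)
            .destination(destination)
            .build();

    // Fault tolerance needs durable storage and cannot be combined with pipelining.
    final ControllerQueryKernelConfig faultTolerantConfig =
        ControllerQueryKernelConfig
            .builder()
            .maxRetainedPartitionSketchBytes(100_000_000)
            .durableStorage(true)
            .faultTolerance(true)
            .destination(destination)
            .build();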
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.kernel.controller; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.indexing.destination.MSQDestination; +import org.apache.druid.msq.indexing.destination.MSQSelectDestination; +import org.apache.druid.msq.input.InputSpec; +import org.apache.druid.msq.input.InputSpecs; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.StageDefinition; +import org.apache.druid.msq.kernel.StageId; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; + +/** + * Utilties for {@link ControllerQueryKernel}. + */ +public class ControllerQueryKernelUtils +{ + /** + * Put stages from {@link QueryDefinition} into groups that must each be launched simultaneously. + * + * This method's goal is to maximize the usage of {@link OutputChannelMode#MEMORY} channels, subject to constraints + * provided by {@link ControllerQueryKernelConfig#isPipeline()}, + * {@link ControllerQueryKernelConfig#getMaxConcurrentStages()}, and + * {@link ControllerQueryKernelConfig#isFaultTolerant()}. + * + * An important part of the logic here is determining the output channel mode of the final stage in a group, i.e. + * {@link StageGroup#lastStageOutputChannelMode()}. + * + * If the {@link StageGroup#lastStageOutputChannelMode()} is not {@link OutputChannelMode#MEMORY}, then the stage + * group is fully executed, and fully generates its output, prior to any downstream stage groups starting. + * + * On the other hand, if {@link StageGroup#lastStageOutputChannelMode()} is {@link OutputChannelMode#MEMORY}, the + * stage group executes up to such a point that the group's last stage has results ready-to-read; see + * {@link ControllerQueryKernel#readyToReadResults(StageId, ControllerStagePhase)}. A downstream stage group, if any, + * is started while the current group is still running. This enables them to transfer data in memory. + * + * Stage groups always end when some stage in them sorts during shuffle, i.e. returns true from + * ({@link StageDefinition#doesSortDuringShuffle()}). This enables "leapfrog" execution, where a sequence + * of sorting stages in separate groups can all run with {@link OutputChannelMode#MEMORY}, even when there are more + * stages than the maxConcurrentStages parameter. To achieve this, we wind down upstream stage groups prior to + * starting downstream stage groups, such that only two groups are ever running at a time. + * + * For example, consider a case where pipeline = true and maxConcurrentStages = 2, and the query has three stages, + * all of which sort during shuffle. The expected return from this method is a list of 3 stage groups, each with + * one stage, and each with {@link StageGroup#lastStageOutputChannelMode()} set to {@link OutputChannelMode#MEMORY}. + * To stay within maxConcurrentStages = 2, execution leapfrogs in the following manner (note- not all transitions + * are shown here, for brevity): + * + *

    + *
+ * <ol>
+ *   <li>Stage 0 starts</li>
+ *   <li>Stage 0 enters {@link ControllerStagePhase#POST_READING}, finishes sorting</li>
+ *   <li>Stage 1 enters {@link ControllerStagePhase#READING_INPUT}</li>
+ *   <li>Stage 1 enters {@link ControllerStagePhase#POST_READING}, finishes sorting</li>
+ *   <li>Stage 0 stops, ends in {@link ControllerStagePhase#FINISHED}</li>
+ *   <li>Stage 2 starts</li>
+ *   <li>Stage 2 enters {@link ControllerStagePhase#POST_READING}, finishes sorting</li>
+ *   <li>Stage 1 stops, ends in {@link ControllerStagePhase#FINISHED}</li>
+ *   <li>Stage 2 stops and query completes</li>
+ * </ol>
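For this particular example, the expected output of computeStageGroups can be written out directly, per the description above: three single-stage groups, each producing its output in memory (an illustrative sketch assuming the classes added in this patch; the query ID is hypothetical):

    import com.google.common.collect.ImmutableList;
    import java.util.List;
    import org.apache.druid.msq.exec.OutputChannelMode;
    import org.apache.druid.msq.kernel.StageId;
    import org.apache.druid.msq.kernel.controller.StageGroup;

    class LeapfrogExample
    {
      public static void main(String[] args)
      {
        // Three stage groups of one sorting stage each, all readable in memory.
        final List<StageGroup> expected = ImmutableList.of(
            new StageGroup(ImmutableList.of(new StageId("q", 0)), OutputChannelMode.MEMORY),
            new StageGroup(ImmutableList.of(new StageId("q", 1)), OutputChannelMode.MEMORY),
            new StageGroup(ImmutableList.of(new StageId("q", 2)), OutputChannelMode.MEMORY)
        );
        System.out.println(expected);
      }
    }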
+ * + * When maxConcurrentStages = 2, leapfrogging is only possible with stage groups containing a single stage each. + * When maxConcurrentStages > 2, leapfrogging can happen with larger stage groups containing longer chains. + */ + public static List computeStageGroups( + final QueryDefinition queryDef, + final ControllerQueryKernelConfig config + ) + { + final MSQDestination destination = config.getDestination(); + final List stageGroups = new ArrayList<>(); + final boolean useDurableStorage = config.isDurableStorage(); + final Map> inflow = computeStageInflowMap(queryDef); + final Map> outflow = computeStageOutflowMap(queryDef); + final Set stagesRun = new HashSet<>(); + + // This loop simulates execution of all stages, such that we arrive at an order of execution that is compatible + // with all relevant constraints. + + while (stagesRun.size() < queryDef.getStageDefinitions().size()) { + // 1) Find all stages that are ready to run, and that cannot use MEMORY output modes. Run them as solo groups. + boolean didRun; + do { + didRun = false; + + for (final StageId stageId : ImmutableList.copyOf(inflow.keySet())) { + if (!stagesRun.contains(stageId) /* stage has not run yet */ + && inflow.get(stageId).isEmpty() /* stage data is fully available */ + && !canUseMemoryOutput(queryDef, stageId.getStageNumber(), config, outflow)) { + stagesRun.add(stageId); + stageGroups.add( + new StageGroup( + Collections.singletonList(stageId), + getOutputChannelMode( + queryDef, + stageId.getStageNumber(), + destination.toSelectDestination(), + useDurableStorage, + false + ) + ) + ); + + // Simulate this stage finishing. + removeStageFlow(stageId, inflow, outflow); + didRun = true; + } + } + } while (didRun); + + // 2) Pick some stage that can use MEMORY output mode, and run that as well as all ready-to-run dependents. + StageId currentStageId = null; + for (final StageId stageId : ImmutableList.copyOf(inflow.keySet())) { + if (!stagesRun.contains(stageId) + && inflow.get(stageId).isEmpty() + && canUseMemoryOutput(queryDef, stageId.getStageNumber(), config, outflow)) { + currentStageId = stageId; + break; + } + } + + if (currentStageId == null) { + // Didn't find a stage that can use MEMORY output mode. + continue; + } + + // Found a stage that can use MEMORY output mode. Build a maximally-sized StageGroup around it. + final List currentStageGroup = new ArrayList<>(); + + // maxStageGroupSize => largest size this stage group can be while respecting maxConcurrentStages and leaving + // room for a priorGroup to run concurrently (if priorGroup uses MEMORY output mode). + final int maxStageGroupSize; + + if (stageGroups.isEmpty()) { + maxStageGroupSize = config.getMaxConcurrentStages(); + } else { + final StageGroup priorGroup = stageGroups.get(stageGroups.size() - 1); + if (priorGroup.lastStageOutputChannelMode() == OutputChannelMode.MEMORY) { + // Prior group runs concurrently with this group. (Can happen when leapfrogging; see class-level Javadoc.) + + // Note: priorGroup.size() is strictly less than config.getMaxConcurrentStages(), because the prior group + // would have its size limited by maxStageGroupSizeAllowingForDownstreamConsumer below. + + maxStageGroupSize = config.getMaxConcurrentStages() - priorGroup.size(); + } else { + // Prior group exits before this group starts. 
+ maxStageGroupSize = config.getMaxConcurrentStages(); + } + } + + OutputChannelMode currentOutputChannelMode = null; + while (currentStageId != null) { + final boolean canUseMemoryOuput = + canUseMemoryOutput(queryDef, currentStageId.getStageNumber(), config, outflow); + final Set currentOutflow = outflow.get(currentStageId); + + // maxStageGroupSizeAllowingForDownstreamConsumer => largest size this stage group can be while still being + // able to use MEMORY output mode. (With MEMORY output mode, the downstream consumer must run concurrently.) + final int maxStageGroupSizeAllowingForDownstreamConsumer; + + if (queryDef.getStageDefinition(currentStageId).doesSortDuringShuffle()) { + // When the current group sorts, there's a pipeline break, so we can leapfrog: close the prior group before + // starting the downstream group. In this case, we only need to reserve a single concurrent-stage slot for + // a downstream consumer. + + // Note: the only stage that can possibly sort is the final stage, because of the check below that closes + // off the stage group if currentStageId "doesSortDuringShuffle()". + + maxStageGroupSizeAllowingForDownstreamConsumer = config.getMaxConcurrentStages() - 1; + } else { + // When the final stage in the current group doesn't sort, we can't leapfrog. We need to reserve a single + // concurrent-stage slot for a downstream consumer, plus enough space to keep the priorGroup running (which + // is accounted for in maxStageGroupSize). + maxStageGroupSizeAllowingForDownstreamConsumer = maxStageGroupSize - 1; + } + + currentOutputChannelMode = + getOutputChannelMode( + queryDef, + currentStageId.getStageNumber(), + config.getDestination().toSelectDestination(), + config.isDurableStorage(), + canUseMemoryOuput + + // Stages can only use MEMORY output mode if they have <= 1 downstream stage (checked by + // "canUseMemoryOutput") and if that downstream stage has all of its other inputs available. + && (currentOutflow.isEmpty() || + Collections.singleton(currentStageId) + .equals(inflow.get(Iterables.getOnlyElement(currentOutflow)))) + + // And, stages can only use MEMORY output mode if their downstream consumer is able to start + // concurrently. So, once the stage group gets too big, we must stop using MEMORY output mode. + && (currentOutflow.isEmpty() + || currentStageGroup.size() < maxStageGroupSizeAllowingForDownstreamConsumer) + ); + + currentStageGroup.add(currentStageId); + + if (currentOutflow.size() == 1 + && currentStageGroup.size() < maxStageGroupSize + && currentOutputChannelMode == OutputChannelMode.MEMORY + + // Sorting causes a pipeline break: a sorting stage must read all its input before writing any output. + // Continue the stage group only if this stage does not sort its output. + && !queryDef.getStageDefinition(currentStageId).doesSortDuringShuffle()) { + currentStageId = Iterables.getOnlyElement(currentOutflow); + } else { + currentStageId = null; + } + } + + stageGroups.add(new StageGroup(currentStageGroup, currentOutputChannelMode)); + + // Simulate this stage group finishing. + for (final StageId stageId : currentStageGroup) { + stagesRun.add(stageId); + removeStageFlow(stageId, inflow, outflow); + } + } + + return stageGroups; + } + + /** + * Returns a mapping of stage -> stages that flow *into* that stage. Uses TreeMaps and TreeSets so we have a + * consistent order for analyzing and running stages. 
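To make the two maps concrete before the implementations that follow: for a linear query in which stage 1 reads stage 0 and stage 2 reads stage 1, they contain the entries below (a standalone sketch, not part of this patch, with plain ints for stage IDs):

    import java.util.Map;
    import java.util.Set;

    class StageFlowExample
    {
      public static void main(String[] args)
      {
        final Map<Integer, Set<Integer>> inflow = Map.of(
            0, Set.of(),   // stage 0 reads no other stage
            1, Set.of(0),  // stage 1 waits on stage 0
            2, Set.of(1)   // stage 2 waits on stage 1
        );
        final Map<Integer, Set<Integer>> outflow = Map.of(
            0, Set.of(1),  // stage 0 feeds stage 1
            1, Set.of(2),  // stage 1 feeds stage 2
            2, Set.of()    // stage 2 is the final stage
        );

        // computeStageGroups drains copies of these maps while simulating
        // execution: a stage may launch once its inflow set is empty, and
        // removeStageFlow erases a finished stage from its consumers' inflow.
        for (final Map.Entry<Integer, Set<Integer>> entry : inflow.entrySet()) {
          if (entry.getValue().isEmpty()) {
            System.out.println("launchable immediately: stage " + entry.getKey());
          }
        }
        System.out.println("outflow of stage 0: " + outflow.get(0));
      }
    }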
+ */ + public static Map> computeStageInflowMap(final QueryDefinition queryDefinition) + { + final Map> retVal = new TreeMap<>(); + + for (final StageDefinition stageDef : queryDefinition.getStageDefinitions()) { + final StageId stageId = stageDef.getId(); + retVal.computeIfAbsent(stageId, ignored -> new TreeSet<>()); + + for (final int inputStageNumber : queryDefinition.getStageDefinition(stageId).getInputStageNumbers()) { + final StageId inputStageId = new StageId(queryDefinition.getQueryId(), inputStageNumber); + retVal.computeIfAbsent(stageId, ignored -> new TreeSet<>()).add(inputStageId); + } + } + + return retVal; + } + + /** + * Returns a mapping of stage -> stages that depend on that stage. Uses TreeMaps and TreeSets so we have a consistent + * order for analyzing and running stages. + */ + public static Map> computeStageOutflowMap(final QueryDefinition queryDefinition) + { + final Map> retVal = new TreeMap<>(); + + for (final StageDefinition stageDef : queryDefinition.getStageDefinitions()) { + final StageId stageId = stageDef.getId(); + retVal.computeIfAbsent(stageId, ignored -> new TreeSet<>()); + + for (final int inputStageNumber : queryDefinition.getStageDefinition(stageId).getInputStageNumbers()) { + final StageId inputStageId = new StageId(queryDefinition.getQueryId(), inputStageNumber); + retVal.computeIfAbsent(inputStageId, ignored -> new TreeSet<>()).add(stageId); + } + } + + return retVal; + } + + /** + * Whether output of a stage can possibly use {@link OutputChannelMode#MEMORY}. Returning true does not guarantee + * that the stage *will* use {@link OutputChannelMode#MEMORY}. Additional requirements are checked in + * {@link #computeStageGroups(QueryDefinition, ControllerQueryKernelConfig)}. + */ + public static boolean canUseMemoryOutput( + final QueryDefinition queryDefinition, + final int stageNumber, + final ControllerQueryKernelConfig config, + final Map> outflowMap + ) + { + if (config.isFaultTolerant()) { + // Cannot use MEMORY output mode if fault tolerance is enabled: durable storage is required. + return false; + } + + if (!config.isPipeline() || config.getMaxConcurrentStages() < 2) { + // Cannot use MEMORY output mode if pipelining (& running multiple stages at once) is not enabled. + return false; + } + + final StageId stageId = queryDefinition.getStageDefinition(stageNumber).getId(); + final Set outflowStageIds = outflowMap.get(stageId); + + if (outflowStageIds.isEmpty()) { + return true; + } else if (outflowStageIds.size() == 1) { + final StageDefinition outflowStageDef = + queryDefinition.getStageDefinition(Iterables.getOnlyElement(outflowStageIds)); + + // Two things happening here: + // 1) Stages cannot use output mode MEMORY when broadcasting. This is because when using output mode MEMORY, we + // can only support a single reader. + // 2) Downstream stages can only have a single input stage with output mode MEMORY. This isn't strictly + // necessary, but it simplifies the logic around concurrently launching stages. + return stageId.equals(getOnlyNonBroadcastInputAsStageId(outflowStageDef)); + } else { + return false; + } + } + + /** + * Return an {@link OutputChannelMode} to use for a given stage, based on query and context. 
+ */ + public static OutputChannelMode getOutputChannelMode( + final QueryDefinition queryDef, + final int stageNumber, + @Nullable final MSQSelectDestination selectDestination, + final boolean durableStorage, + final boolean canStream + ) + { + final boolean isFinalStage = queryDef.getFinalStageDefinition().getStageNumber() == stageNumber; + + if (isFinalStage && selectDestination == MSQSelectDestination.DURABLESTORAGE) { + return OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS; + } else if (canStream) { + return OutputChannelMode.MEMORY; + } else if (durableStorage) { + return OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE; + } else { + return OutputChannelMode.LOCAL_STORAGE; + } + } + + /** + * If a stage has a single non-broadcast input stage, returns that input stage. Otherwise, returns null. + * This is a helper used by {@link #canUseMemoryOutput}. + */ + @Nullable + public static StageId getOnlyNonBroadcastInputAsStageId(final StageDefinition downstreamStageDef) + { + final List inputSpecs = downstreamStageDef.getInputSpecs(); + final IntSet broadcastInputNumbers = downstreamStageDef.getBroadcastInputNumbers(); + + if (inputSpecs.size() - broadcastInputNumbers.size() != 1) { + return null; + } + + for (int i = 0; i < inputSpecs.size(); i++) { + if (!broadcastInputNumbers.contains(i)) { + final IntSet stageNumbers = InputSpecs.getStageNumbers(Collections.singletonList(inputSpecs.get(i))); + if (stageNumbers.size() == 1) { + return new StageId(downstreamStageDef.getId().getQueryId(), stageNumbers.iterator().nextInt()); + } + } + } + + return null; + } + + /** + * Remove all outflow from "stageId". At the conclusion of this method, "outflow" has an empty set for "stageId", + * and no sets in "inflow" reference "stageId". Outflow and inflow sets may become empty as a result of this + * operation. Sets that become empty are left empty, not removed from the inflow and outflow maps. + */ + private static void removeStageFlow( + final StageId stageId, + final Map> inflow, + final Map> outflow + ) + { + for (final StageId outStageId : outflow.get(stageId)) { + inflow.get(outStageId).remove(stageId); + } + + outflow.get(stageId).clear(); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStagePhase.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStagePhase.java index 3f8f3d19b3f8..eb124ab5b9f3 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStagePhase.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStagePhase.java @@ -19,9 +19,12 @@ package org.apache.druid.msq.kernel.controller; -import com.google.common.collect.ImmutableSet; - -import java.util.Set; +import org.apache.druid.msq.exec.ClusterStatisticsMergeMode; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.kernel.ShuffleKind; +import org.apache.druid.msq.kernel.ShuffleSpec; +import org.apache.druid.msq.kernel.StageDefinition; +import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; /** * Phases that a stage can be in, as far as the controller is concerned. @@ -30,7 +33,12 @@ */ public enum ControllerStagePhase { - // Not doing anything yet. Just recently initialized. + /** + * Not doing anything yet. Just recently initialized. 
+ * + * When using {@link OutputChannelMode#MEMORY}, entering this phase tells us that it is time to launch the consumer + * stage (see {@link ControllerQueryKernel#readyToReadResults}). + */ NEW { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -39,7 +47,12 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) } }, - // Reading and mapping inputs (using "stateless" operators like filters, transforms which operate on individual records). + /** + * Reading inputs. + * + * Stages may transition directly from here to {@link #RESULTS_READY}, or they may go through + * {@link #MERGING_STATISTICS} and {@link #POST_READING}, depending on the {@link ShuffleSpec}. + */ READING_INPUT { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -48,12 +61,16 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) } }, - // Waiting to fetch key statistics in the background from the workers and incrementally generate partitions. - // This phase is only transitioned to once all partialKeyInformation are received from workers. - // Transitioning to this phase should also enqueue the task to fetch key statistics if `SEQUENTIAL` strategy is used. - // In `PARALLEL` strategy, we start fetching the key statistics as soon as they are available on the worker. - // This stage is not required in non-pre shuffle contexts - + /** + * Waiting to fetch key statistics in the background from the workers and incrementally generate partitions. + * + * This phase is only transitioned to once all {@link PartialKeyStatisticsInformation} are received from workers. + * Transitioning to this phase should also enqueue the task to fetch key statistics if + * {@link ClusterStatisticsMergeMode#SEQUENTIAL} strategy is used. In {@link ClusterStatisticsMergeMode#PARALLEL} + * strategy, we start fetching the key statistics as soon as they are available on the worker. + * + * This stage is used if, and only if, {@link StageDefinition#mustGatherResultKeyStatistics()}. + */ MERGING_STATISTICS { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -62,18 +79,29 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) } }, - // Post the inputs have been read and mapped to frames, in the `POST_READING` stage, we pre-shuffle and determining the partition boundaries. - // This step for a stage spits out the statistics of the data as a whole (and not just the individual records). This - // phase is not required in non-pre shuffle contexts. + /** + * Inputs have been completely read, and sorting is in progress. + * + * When using {@link OutputChannelMode#MEMORY} with {@link StageDefinition#doesSortDuringShuffle()}, entering this + * phase tells us that it is time to launch the consumer stage (see {@link ControllerQueryKernel#readyToReadResults}). + * + * This phase is only used when {@link ShuffleKind#isSort()}. Note that it may not *always* be used even when sorting; + * for example, when not using {@link OutputChannelMode#MEMORY} and also not gathering statistics + * ({@link StageDefinition#mustGatherResultKeyStatistics()}), a stage phase may transition directly from + * {@link #READING_INPUT} to {@link #RESULTS_READY}. 
+ */ POST_READING { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) { - return priorPhase == MERGING_STATISTICS; + return priorPhase == READING_INPUT /* when sorting locally */ + || priorPhase == MERGING_STATISTICS /* when sorting globally */; } }, - // Done doing work and all results have been generated. + /** + * Done doing work, and all results have been generated. + */ RESULTS_READY { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -82,17 +110,25 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) } }, - // The worker outputs for this stage might have been cleaned up in the workers, and they cannot be used by - // any other phase. "Metadata" for the stage such as counters are still available however + /** + * Stage has completed successfully and has been cleaned up. Worker outputs for this stage are no longer + * available and cannot be used by any other stage. Metadata such as counters are still available. + * + * Any non-terminal phase can transition to FINISHED. This can even happen prior to RESULTS_READY, if the + * controller determines that the outputs of the stage are no longer needed. For example, this happens when + * a downstream consumer is reading with limit, and decides it's finished processing. + */ FINISHED { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) { - return priorPhase == RESULTS_READY; + return !priorPhase.isTerminal(); } }, - // Something went wrong. + /** + * Something went wrong. + */ FAILED { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -101,9 +137,11 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) } }, - // Stages whose workers are currently under relaunch. We can transition out of Retrying state only when all the work orders - // of this stage have been sent. - // We can transition into Retrying phase when the prior phase did not publish its final results yet. + /** + * Stages whose workers are currently under relaunch. We can transition out of this phase only when all the work + * orders of this stage have been sent. We can transition into this phase when the prior phase did not + * publish its final results yet. + */ RETRYING { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -117,30 +155,40 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) public abstract boolean canTransitionFrom(ControllerStagePhase priorPhase); - private static final Set TERMINAL_PHASES = ImmutableSet.of( - RESULTS_READY, - FINISHED - ); + /** + * Whether this phase indicates that the stage has been started and is still running. (It hasn't been cleaned up + * or failed yet.) + */ + public boolean isRunning() + { + return this == READING_INPUT + || this == MERGING_STATISTICS + || this == POST_READING + || this == RESULTS_READY + || this == RETRYING; + } /** - * @return true if the phase indicates that the stage has completed its work and produced results successfully + * Whether this phase indicates that the stage has consumed its inputs from the previous stages successfully. 
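The predicates introduced in this hunk (isRunning above; isDoneReadingInput, isSuccess, and isTerminal just below) can be summarized with a small assertion-style sketch; this assumes the enum is on the classpath and assertions are enabled with -ea:

    import org.apache.druid.msq.kernel.controller.ControllerStagePhase;

    class PhasePredicateExample
    {
      public static void main(String[] args)
      {
        // POST_READING: inputs fully consumed, still running, not yet successful.
        assert ControllerStagePhase.POST_READING.isRunning();
        assert ControllerStagePhase.POST_READING.isDoneReadingInput();
        assert !ControllerStagePhase.POST_READING.isSuccess();
        assert !ControllerStagePhase.POST_READING.isTerminal();

        // FINISHED: successful and terminal; no longer counts as running.
        assert !ControllerStagePhase.FINISHED.isRunning();
        assert ControllerStagePhase.FINISHED.isDoneReadingInput();
        assert ControllerStagePhase.FINISHED.isSuccess();
        assert ControllerStagePhase.FINISHED.isTerminal();

        // FAILED: terminal but not successful.
        assert ControllerStagePhase.FAILED.isTerminal();
        assert !ControllerStagePhase.FAILED.isSuccess();
      }
    }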
*/ - public static boolean isSuccessfulTerminalPhase(final ControllerStagePhase phase) + public boolean isDoneReadingInput() { - return TERMINAL_PHASES.contains(phase); + return this == POST_READING || this == RESULTS_READY || this == FINISHED; } - private static final Set POST_READING_PHASES = ImmutableSet.of( - POST_READING, - RESULTS_READY, - FINISHED - ); + /** + * Whether this phase indicates that the stage has completed its work and produced results successfully. + */ + public boolean isSuccess() + { + return this == RESULTS_READY || this == FINISHED; + } /** - * @return true if the phase indicates that the stage has consumed its inputs from the previous stages successfully + * Whether this phase indicates that the stage is no longer running. */ - public static boolean isPostReadingPhase(final ControllerStagePhase phase) + public boolean isTerminal() { - return POST_READING_PHASES.contains(phase); + return this == FINISHED || this == FAILED; } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java index e0190bfacb34..0a62ba24b639 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java @@ -26,11 +26,13 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.key.ClusterByPartition; import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.java.util.common.Either; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.exec.ClusterStatisticsMergeMode; @@ -171,7 +173,14 @@ static ControllerStageTracker create( final long maxInputBytesPerWorker ) { - final WorkerInputs workerInputs = WorkerInputs.create(stageDef, stageWorkerCountMap, slicer, assignmentStrategy, maxInputBytesPerWorker); + final WorkerInputs workerInputs = WorkerInputs.create( + stageDef, + stageWorkerCountMap, + slicer, + assignmentStrategy, + maxInputBytesPerWorker + ); + return new ControllerStageTracker( stageDef, workerInputs, @@ -331,12 +340,15 @@ boolean collectorEncounteredAnyMultiValueField() */ Object getResultObject() { - if (phase == ControllerStagePhase.FINISHED) { - throw new ISE("Result object has been cleaned up prematurely"); - } else if (phase != ControllerStagePhase.RESULTS_READY) { - throw new ISE("Result object is not ready yet"); + if (!phase.isSuccess()) { + throw new ISE("Result object for stage[%s] is not ready yet", stageDef.getId()); } else if (resultObject == null) { - throw new NullPointerException("resultObject was unexpectedly null"); + throw new NullPointerException( + StringUtils.format( + "Result object for stage[%s] was unexpectedly null", + stageDef.getId() + ) + ); } else { return resultObject; } @@ -382,7 +394,7 @@ public CompleteKeyStatisticsInformation getCompleteKeyStatisticsInformation() * @param workerNumber the worker * @param partialKeyStatisticsInformation partial key 
statistics */ - ControllerStagePhase addPartialKeyInformationForWorker( + void addPartialKeyInformationForWorker( final int workerNumber, final PartialKeyStatisticsInformation partialKeyStatisticsInformation ) @@ -412,7 +424,7 @@ ControllerStagePhase addPartialKeyInformationForWorker( if (partialKeyStatisticsInformation.getTimeSegments().contains(null)) { // Time should not contain null value failForReason(InsertTimeNullFault.instance()); - return getPhase(); + return; } completeKeyStatisticsInformation.mergePartialInformation(workerNumber, partialKeyStatisticsInformation); } @@ -470,7 +482,6 @@ ControllerStagePhase addPartialKeyInformationForWorker( fail(); throw e; } - return getPhase(); } private void initializeTimeChunkWorkerTrackers() @@ -502,7 +513,6 @@ private void initializeTimeChunkWorkerTrackers() *

* If all the stats from all the workers are merged, we transition the stage to {@link ControllerStagePhase#POST_READING} */ - void mergeClusterByStatisticsCollectorForTimeChunk( int workerNumber, Long timeChunk, @@ -762,6 +772,58 @@ void setClusterByPartitionBoundaries(ClusterByPartitions clusterByPartitions) transitionTo(ControllerStagePhase.POST_READING); } + /** + * Transitions phase directly from {@link ControllerStagePhase#READING_INPUT} to + * {@link ControllerStagePhase#POST_READING}, skipping {@link ControllerStagePhase#MERGING_STATISTICS}. + * This method is used for stages that sort but do not need to gather result key statistics. + */ + void setDoneReadingInputForWorker(final int workerNumber) + { + if (stageDef.mustGatherResultKeyStatistics()) { + throw DruidException.defensive( + "Cannot setDoneReadingInput for stage[%s], it should send partial key information instead", + stageDef.getId() + ); + } + + if (!stageDef.doesSortDuringShuffle()) { + throw DruidException.defensive("Cannot setDoneReadingInput for stage[%s], it is not sorting", stageDef.getId()); + } + + if (workerNumber < 0 || workerNumber >= workerCount) { + throw new IAE("Invalid workerNumber[%s] for stage[%s]", workerNumber, stageDef.getId()); + } + + ControllerWorkerStagePhase currentPhase = workerToPhase.get(workerNumber); + + if (currentPhase == null) { + throw new ISE("Worker[%d] not found for stage[%s]", workerNumber, stageDef.getId()); + } + + try { + if (ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT.canTransitionFrom(currentPhase)) { + workerToPhase.put(workerNumber, ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT); + + if (allWorkersDoneReadingInput()) { + transitionTo(ControllerStagePhase.POST_READING); + } + } else { + throw new ISE( + "Worker[%d] for stage[%d] expected to be in phase that can transition to[%s]. Found phase[%s]", + workerNumber, + stageDef.getStageNumber(), + ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT, + currentPhase + ); + } + } + catch (Exception e) { + // If this op fails, we're in an inconsistent state and must cancel the stage. + fail(); + throw e; + } + } + /** * Accepts and sets the results that each worker produces for this particular stage * @@ -937,6 +999,21 @@ public boolean allPartialKeyInformationFetched() == workerCount; } + /** + * True if all workers are done reading their inputs. + */ + public boolean allWorkersDoneReadingInput() + { + for (final ControllerWorkerStagePhase phase : workerToPhase.values()) { + if (phase != ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT + && phase != ControllerWorkerStagePhase.RESULTS_READY) { + return false; + } + } + + return true; + } + /** * True if all {@link org.apache.druid.msq.kernel.WorkOrder} are sent else false. 
*/ @@ -973,7 +1050,7 @@ private void transitionTo(final ControllerStagePhase newPhase) if (newPhase.canTransitionFrom(phase)) { phase = newPhase; } else { - throw new IAE("Cannot transition from [%s] to [%s]", phase, newPhase); + throw new IAE("Cannot transition stage[%s] from[%s] to[%s]", stageDef.getId(), phase, newPhase); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerWorkerStagePhase.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerWorkerStagePhase.java index 1c3e370dc80e..89eca8f83755 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerWorkerStagePhase.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerWorkerStagePhase.java @@ -69,7 +69,8 @@ public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase) @Override public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase) { - return priorPhase == PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES; + return priorPhase == READING_INPUT /* when sorting locally */ + || priorPhase == PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES /* when sorting globally */; } }, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/StageGroup.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/StageGroup.java new file mode 100644 index 000000000000..f58eb3ee9c39 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/StageGroup.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.kernel.controller; + +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.kernel.StageId; + +import java.util.List; +import java.util.Objects; + +/** + * Group of stages that must be launched as a unit. Within each group, stages communicate with each other using + * {@link OutputChannelMode#MEMORY} channels. The final stage in a group writes its own output using + * {@link #lastStageOutputChannelMode()}. + * + * Stages in a group have linear (non-branching) data flow: the first stage is an input to the second stage, the second + * stage is an input to the third stage, and so on. This is done to simplify the logic. In the future, it is possible + * that stage groups may contain branching data flow. 
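A usage sketch for the StageGroup class defined just below; the query ID and stage layout are hypothetical, and the MSQ classes from this patch are assumed to be on the classpath:

    import com.google.common.collect.ImmutableList;
    import org.apache.druid.msq.exec.OutputChannelMode;
    import org.apache.druid.msq.kernel.StageId;
    import org.apache.druid.msq.kernel.controller.StageGroup;

    class StageGroupExample
    {
      public static void main(String[] args)
      {
        final StageId stage0 = new StageId("myQueryId", 0);
        final StageId stage1 = new StageId("myQueryId", 1);

        // A two-stage group whose final stage writes to local storage.
        final StageGroup group =
            new StageGroup(ImmutableList.of(stage0, stage1), OutputChannelMode.LOCAL_STORAGE);

        // Within the group, stage 0 feeds stage 1 through an in-memory channel.
        System.out.println(group.stageOutputChannelMode(stage0)); // MEMORY
        // The group's last stage writes its output using the group-level mode.
        System.out.println(group.lastStageOutputChannelMode());   // LOCAL_STORAGE
        System.out.println(group.first() + " .. " + group.last() + ", size=" + group.size());
      }
    }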
+ */ +public class StageGroup +{ + private final List stageIds; + private final OutputChannelMode groupOutputChannelMode; + + public StageGroup(final List stageIds, final OutputChannelMode groupOutputChannelMode) + { + this.stageIds = stageIds; + this.groupOutputChannelMode = groupOutputChannelMode; + } + + /** + * List of stage IDs in this group. + * + * The first stage is an input to the second stage, the second stage is an input to the third stage, and so on. + * See class-level javadocs for more details. + */ + public List stageIds() + { + return stageIds; + } + + /** + * Output mode of the final stage in this group. + */ + public OutputChannelMode lastStageOutputChannelMode() + { + return stageOutputChannelMode(last()); + } + + /** + * Output mode of the given stage. + */ + public OutputChannelMode stageOutputChannelMode(final StageId stageId) + { + if (last().equals(stageId)) { + return groupOutputChannelMode; + } else if (stageIds.contains(stageId)) { + return OutputChannelMode.MEMORY; + } else { + throw new IAE("Stage[%s] not in group", stageId); + } + } + + /** + * First stage in this group. + */ + public StageId first() + { + return stageIds.get(0); + } + + /** + * Last stage in this group. + */ + public StageId last() + { + return stageIds.get(stageIds.size() - 1); + } + + /** + * Number of stages in this group. + */ + public int size() + { + return stageIds.size(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + StageGroup that = (StageGroup) o; + return Objects.equals(stageIds, that.stageIds) && groupOutputChannelMode == that.groupOutputChannelMode; + } + + @Override + public int hashCode() + { + return Objects.hash(stageIds, groupOutputChannelMode); + } + + @Override + public String toString() + { + return "StageGroup{" + + "stageIds=" + stageIds + + ", groupOutputChannelMode=" + groupOutputChannelMode + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java index ed0807475ef6..09c48aa942c9 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java @@ -125,8 +125,7 @@ public ProcessorsAndChannels makeProcessors( final OutputChannel outputChannel = outputChannelFactory.openChannel(0 /* Partition number doesn't matter */); outputChannels.add(outputChannel); channelQueue.add(outputChannel.getWritableChannel()); - frameWriterFactoryQueue.add(stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator()) - ); + frameWriterFactoryQueue.add(stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator())); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java index c532dcee56e8..831c9b139d3d 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java @@ -555,7 +555,7 @@ private static DataSourcePlan forUnion( // This is done to prevent loss of 
generality since MSQ can plan any type of DataSource.
     List<DataSource> children = unionDataSource.getDataSources();

-    final QueryDefinitionBuilder subqueryDefBuilder = QueryDefinition.builder();
+    final QueryDefinitionBuilder subqueryDefBuilder = QueryDefinition.builder(queryId);
     final List<DataSource> newChildren = new ArrayList<>();
     final List<InputSpec> inputSpecs = new ArrayList<>();
     final IntSet broadcastInputs = new IntOpenHashSet();
@@ -605,7 +605,7 @@ private static DataSourcePlan forBroadcastHashJoin(
       final boolean broadcast
   )
   {
-    final QueryDefinitionBuilder subQueryDefBuilder = QueryDefinition.builder();
+    final QueryDefinitionBuilder subQueryDefBuilder = QueryDefinition.builder(queryId);
     final DataSourceAnalysis analysis = dataSource.getAnalysis();

     final DataSourcePlan basePlan = forDataSource(
@@ -683,7 +683,7 @@ private static DataSourcePlan forSortMergeJoin(
         SortMergeJoinFrameProcessorFactory.validateCondition(dataSource.getConditionAnalysis())
     );

-    final QueryDefinitionBuilder subQueryDefBuilder = QueryDefinition.builder();
+    final QueryDefinitionBuilder subQueryDefBuilder = QueryDefinition.builder(queryId);

     // Plan the left input.
     // We're confident that we can cast dataSource.getLeft() to QueryDataSource, because DruidJoinQueryRel creates
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java
index b9c0f1a0d262..85c0c14e16e9 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java
@@ -20,6 +20,7 @@
 package org.apache.druid.msq.querykit;

 import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.msq.kernel.GlobalSortShuffleSpec;
 import org.apache.druid.msq.kernel.ShuffleSpec;

 /**
@@ -29,7 +30,7 @@ public interface ShuffleSpecFactory
 {
   /**
    * Build a {@link ShuffleSpec} for given {@link ClusterBy}. The {@code aggregate} flag is used to populate
-   * {@link ShuffleSpec#doesAggregate()}.
+   * {@link GlobalSortShuffleSpec#doesAggregate()}.
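+   *
+   * A hedged sketch of typical usage (the concrete factory instance is an assumption for illustration, e.g. one of
+   * the {@code ShuffleSpecFactories} helpers; it is not prescribed by this interface):
+   *
+   * <pre>{@code
+   * ShuffleSpecFactory factory = ShuffleSpecFactories.globalSortWithMaxPartitionCount(1);
+   * ShuffleSpec spec = factory.build(clusterBy, true); // for a global sort, doesAggregate() then returns true
+   * }</pre>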
*/ ShuffleSpec build(ClusterBy clusterBy, boolean aggregate); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java index 0bbe8eb91aed..d08d78ef791f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java @@ -76,7 +76,7 @@ public QueryDefinition makeQueryDefinition( ShuffleSpec nextShuffleSpec = findShuffleSpecForNextWindow(operatorList.get(0), maxWorkerCount); // add this shuffle spec to the last stage of the inner query - final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder().queryId(queryId); + final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder(queryId); if (nextShuffleSpec != null) { final ClusterBy windowClusterBy = nextShuffleSpec.clusterBy(); originalQuery = (WindowOperatorQuery) originalQuery.withOverriddenContext(ImmutableMap.of( @@ -178,7 +178,7 @@ public QueryDefinition makeQueryDefinition( ); } } - return queryDefBuilder.queryId(queryId).build(); + return queryDefBuilder.build(); } /** diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java index 96b4b77f159b..f02e505d0c5a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java @@ -78,7 +78,7 @@ public QueryDefinition makeQueryDefinition( { validateQuery(originalQuery); - final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder().queryId(queryId); + final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder(queryId); final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource( queryKit, queryId, @@ -240,7 +240,7 @@ public QueryDefinition makeQueryDefinition( } } - return queryDefBuilder.queryId(queryId).build(); + return queryDefBuilder.build(); } /** diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java index 8bc6f0bfa96d..2927264382a4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java @@ -92,7 +92,7 @@ public QueryDefinition makeQueryDefinition( final int minStageNumber ) { - final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder().queryId(queryId); + final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder(queryId); final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource( queryKit, queryId, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java new file mode 100644 index 000000000000..92042d59a8a8 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java @@ -0,0 +1,270 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.rpc; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.frame.channel.ReadableByteChunksFrameChannel; +import org.apache.druid.frame.file.FrameFileHttpResponseHandler; +import org.apache.druid.frame.file.FrameFilePartialFetch; +import org.apache.druid.frame.key.ClusterByPartitions; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.java.util.http.client.response.BytesFullResponseHandler; +import org.apache.druid.java.util.http.client.response.BytesFullResponseHolder; +import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.exec.WorkerClient; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; +import org.apache.druid.rpc.IgnoreHttpResponseHandler; +import org.apache.druid.rpc.RequestBuilder; +import org.apache.druid.rpc.ServiceClient; +import org.jboss.netty.handler.codec.http.HttpMethod; + +import javax.annotation.Nonnull; +import javax.ws.rs.core.HttpHeaders; +import java.io.IOException; + +/** + * Base worker client. Subclasses override {@link #getClient(String)} and {@link #close()} to build a complete client + * for talking to specific types of workers. 
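+ *
+ * For illustration, a minimal subclass might look like the following sketch (the class name and the client map are
+ * hypothetical, not part of this PR):
+ *
+ * <pre>{@code
+ * class SimpleWorkerClient extends BaseWorkerClientImpl
+ * {
+ *   private final Map<String, ServiceClient> clients;
+ *
+ *   SimpleWorkerClient(ObjectMapper jsonMapper, Map<String, ServiceClient> clients)
+ *   {
+ *     super(jsonMapper, MediaType.APPLICATION_JSON);
+ *     this.clients = clients;
+ *   }
+ *
+ *   @Override
+ *   protected ServiceClient getClient(String workerId)
+ *   {
+ *     return clients.get(workerId);
+ *   }
+ *
+ *   @Override
+ *   public void close()
+ *   {
+ *     // No per-client resources to release in this sketch.
+ *   }
+ * }
+ * }</pre>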
+ */
+public abstract class BaseWorkerClientImpl implements WorkerClient
+{
+  private static final Logger log = new Logger(BaseWorkerClientImpl.class);
+
+  private final ObjectMapper objectMapper;
+  private final String contentType;
+
+  protected BaseWorkerClientImpl(final ObjectMapper objectMapper, final String contentType)
+  {
+    this.objectMapper = objectMapper;
+    this.contentType = contentType;
+  }
+
+  @Nonnull
+  public static String getStagePartitionPath(StageId stageId, int partitionNumber)
+  {
+    return StringUtils.format(
+        "/channels/%s/%d/%d",
+        StringUtils.urlEncode(stageId.getQueryId()),
+        stageId.getStageNumber(),
+        partitionNumber
+    );
+  }
+
+  @Override
+  public ListenableFuture<Void> postWorkOrder(String workerId, WorkOrder workOrder)
+  {
+    return getClient(workerId).asyncRequest(
+        new RequestBuilder(HttpMethod.POST, "/workOrder")
+            .objectContent(objectMapper, contentType, workOrder),
+        IgnoreHttpResponseHandler.INSTANCE
+    );
+  }
+
+  @Override
+  public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
+      String workerId,
+      StageId stageId
+  )
+  {
+    String path = StringUtils.format(
+        "/keyStatistics/%s/%d",
+        StringUtils.urlEncode(stageId.getQueryId()),
+        stageId.getStageNumber()
+    );
+
+    return FutureUtils.transform(
+        getClient(workerId).asyncRequest(
+            new RequestBuilder(HttpMethod.POST, path).header(HttpHeaders.ACCEPT, contentType),
+            new BytesFullResponseHandler()
+        ),
+        holder -> deserialize(holder, new TypeReference<ClusterByStatisticsSnapshot>() {})
+    );
+  }
+
+  @Override
+  public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
+      String workerId,
+      StageId stageId,
+      long timeChunk
+  )
+  {
+    String path = StringUtils.format(
+        "/keyStatisticsForTimeChunk/%s/%d/%d",
+        StringUtils.urlEncode(stageId.getQueryId()),
+        stageId.getStageNumber(),
+        timeChunk
+    );
+
+    return FutureUtils.transform(
+        getClient(workerId).asyncRequest(
+            new RequestBuilder(HttpMethod.POST, path).header(HttpHeaders.ACCEPT, contentType),
+            new BytesFullResponseHandler()
+        ),
+        holder -> deserialize(holder, new TypeReference<ClusterByStatisticsSnapshot>() {})
+    );
+  }
+
+  @Override
+  public ListenableFuture<Void> postResultPartitionBoundaries(
+      String workerId,
+      StageId stageId,
+      ClusterByPartitions partitionBoundaries
+  )
+  {
+    final String path = StringUtils.format(
+        "/resultPartitionBoundaries/%s/%d",
+        StringUtils.urlEncode(stageId.getQueryId()),
+        stageId.getStageNumber()
+    );
+
+    return getClient(workerId).asyncRequest(
+        new RequestBuilder(HttpMethod.POST, path)
+            .objectContent(objectMapper, contentType, partitionBoundaries),
+        IgnoreHttpResponseHandler.INSTANCE
+    );
+  }
+
+  /**
+   * Client-side method for {@link org.apache.druid.msq.indexing.client.WorkerChatHandler#httpPostCleanupStage}.
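+   *
+   * For example (worker and stage IDs hypothetical), {@code postCleanupStage("worker-0", new StageId("q1", 2))}
+   * issues {@code POST /cleanupStage/q1/2} against that worker, per the path format below.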
+   */
+  @Override
+  public ListenableFuture<Void> postCleanupStage(
+      final String workerId,
+      final StageId stageId
+  )
+  {
+    final String path = StringUtils.format(
+        "/cleanupStage/%s/%d",
+        StringUtils.urlEncode(stageId.getQueryId()),
+        stageId.getStageNumber()
+    );
+
+    return getClient(workerId).asyncRequest(
+        new RequestBuilder(HttpMethod.POST, path),
+        IgnoreHttpResponseHandler.INSTANCE
+    );
+  }
+
+  @Override
+  public ListenableFuture<Void> postFinish(String workerId)
+  {
+    return getClient(workerId).asyncRequest(
+        new RequestBuilder(HttpMethod.POST, "/finish"),
+        IgnoreHttpResponseHandler.INSTANCE
+    );
+  }
+
+  @Override
+  public ListenableFuture<CounterSnapshotsTree> getCounters(String workerId)
+  {
+    return FutureUtils.transform(
+        getClient(workerId).asyncRequest(
+            new RequestBuilder(HttpMethod.GET, "/counters").header(HttpHeaders.ACCEPT, contentType),
+            new BytesFullResponseHandler()
+        ),
+        holder -> deserialize(holder, new TypeReference<CounterSnapshotsTree>() {})
+    );
+  }
+
+  @Override
+  public ListenableFuture<Boolean> fetchChannelData(
+      String workerId,
+      StageId stageId,
+      int partitionNumber,
+      long offset,
+      ReadableByteChunksFrameChannel channel
+  )
+  {
+    final ServiceClient client = getClient(workerId);
+    final String path = getStagePartitionPath(stageId, partitionNumber);
+
+    final SettableFuture<Boolean> retVal = SettableFuture.create();
+    final ListenableFuture<FrameFilePartialFetch> clientFuture =
+        client.asyncRequest(
+            new RequestBuilder(HttpMethod.GET, StringUtils.format("%s?offset=%d", path, offset))
+                .header(HttpHeaders.ACCEPT_ENCODING, "identity"), // Data is compressed at app level
+            new FrameFileHttpResponseHandler(channel)
+        );
+
+    Futures.addCallback(
+        clientFuture,
+        new FutureCallback<FrameFilePartialFetch>()
+        {
+          @Override
+          public void onSuccess(FrameFilePartialFetch partialFetch)
+          {
+            if (partialFetch.isExceptionCaught()) {
+              // Exception while reading channel. Recoverable.
+              log.noStackTrace().info(
+                  partialFetch.getExceptionCaught(),
+                  "Encountered exception while reading channel [%s]",
+                  channel.getId()
+              );
+            }
+
+            // Empty fetch means this is the last fetch for the channel.
+            partialFetch.backpressureFuture().addListener(
+                () -> retVal.set(partialFetch.isLastFetch()),
+                Execs.directExecutor()
+            );
+          }
+
+          @Override
+          public void onFailure(Throwable t)
+          {
+            retVal.setException(t);
+          }
+        },
+        Execs.directExecutor()
+    );
+
+    return retVal;
+  }
+
+  /**
+   * Create a client to communicate with a given worker ID.
+   */
+  protected abstract ServiceClient getClient(String workerId);
+
+  /**
+   * Deserialize a {@link BytesFullResponseHolder} as JSON.
+   *
+   * It would be reasonable to move this to {@link BytesFullResponseHolder} itself, or some shared utility class.
+   */
+  protected <T> T deserialize(final BytesFullResponseHolder bytesHolder, final TypeReference<T> typeReference)
+  {
+    try {
+      return objectMapper.readValue(bytesHolder.getContent(), typeReference);
+    }
+    catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java
new file mode 100644
index 000000000000..d3e9eefa86d2
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.rpc; + +import org.apache.druid.indexer.report.TaskReport; +import org.apache.druid.msq.counters.CounterSnapshots; +import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.indexing.MSQTaskList; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; +import org.apache.druid.server.security.AuthorizerMapper; + +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.Consumes; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import java.util.List; + +public class ControllerResource +{ + private final Controller controller; + private final ResourcePermissionMapper permissionMapper; + private final AuthorizerMapper authorizerMapper; + + public ControllerResource( + final Controller controller, + final ResourcePermissionMapper permissionMapper, + final AuthorizerMapper authorizerMapper + ) + { + this.controller = controller; + this.permissionMapper = permissionMapper; + this.authorizerMapper = authorizerMapper; + } + + /** + * Used by subtasks to post {@link PartialKeyStatisticsInformation} for shuffling stages. + * + * See {@link ControllerClient#postPartialKeyStatistics(StageId, int, PartialKeyStatisticsInformation)} + * for the client-side code that calls this API. + */ + @POST + @Path("/partialKeyStatisticsInformation/{queryId}/{stageNumber}/{workerNumber}") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public Response httpPostPartialKeyStatistics( + final Object partialKeyStatisticsObject, + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @PathParam("workerNumber") final int workerNumber, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + controller.updatePartialKeyStatisticsInformation(stageNumber, workerNumber, partialKeyStatisticsObject); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * Used by subtasks to post system errors. Note that the errors are organized by taskId, not by query/stage/worker, + * because system errors are associated with a task rather than a specific query/stage/worker execution context. + * + * See {@link ControllerClient#postWorkerError} for the client-side code that calls this API. 
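+   *
+   * For example (task ID hypothetical), a worker reporting a fatal error for task {@code query-worker0_0} posts
+   * the serialized {@link MSQErrorReport} to {@code /workerError/query-worker0_0}.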
+   */
+  @POST
+  @Path("/workerError/{taskId}")
+  @Produces(MediaType.APPLICATION_JSON)
+  @Consumes(MediaType.APPLICATION_JSON)
+  public Response httpPostWorkerError(
+      final MSQErrorReport errorReport,
+      @PathParam("taskId") final String taskId,
+      @Context final HttpServletRequest req
+  )
+  {
+    MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req);
+    controller.workerError(errorReport);
+    return Response.status(Response.Status.ACCEPTED).build();
+  }
+
+  /**
+   * Used by subtasks to post system warnings.
+   *
+   * See {@link ControllerClient#postWorkerWarning} for the client-side code that calls this API.
+   */
+  @POST
+  @Path("/workerWarning")
+  @Produces(MediaType.APPLICATION_JSON)
+  @Consumes(MediaType.APPLICATION_JSON)
+  public Response httpPostWorkerWarning(
+      final List<MSQErrorReport> errorReport,
+      @Context final HttpServletRequest req
+  )
+  {
+    MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req);
+    controller.workerWarning(errorReport);
+    return Response.status(Response.Status.ACCEPTED).build();
+  }
+
+  /**
+   * Used by subtasks to post {@link CounterSnapshots} periodically.
+   *
+   * See {@link ControllerClient#postCounters} for the client-side code that calls this API.
+   */
+  @POST
+  @Path("/counters/{taskId}")
+  @Produces(MediaType.APPLICATION_JSON)
+  @Consumes(MediaType.APPLICATION_JSON)
+  public Response httpPostCounters(
+      @PathParam("taskId") final String taskId,
+      final CounterSnapshotsTree snapshotsTree,
+      @Context final HttpServletRequest req
+  )
+  {
+    MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req);
+    controller.updateCounters(taskId, snapshotsTree);
+    return Response.status(Response.Status.OK).build();
+  }
+
+  /**
+   * Used by subtasks to post notifications that their results are ready.
+   *
+   * See {@link ControllerClient#postResultsComplete} for the client-side code that calls this API.
+   */
+  @POST
+  @Path("/resultsComplete/{queryId}/{stageNumber}/{workerNumber}")
+  @Produces(MediaType.APPLICATION_JSON)
+  @Consumes(MediaType.APPLICATION_JSON)
+  public Response httpPostResultsComplete(
+      final Object resultObject,
+      @PathParam("queryId") final String queryId,
+      @PathParam("stageNumber") final int stageNumber,
+      @PathParam("workerNumber") final int workerNumber,
+      @Context final HttpServletRequest req
+  )
+  {
+    MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req);
+    controller.resultsComplete(queryId, stageNumber, workerNumber, resultObject);
+    return Response.status(Response.Status.ACCEPTED).build();
+  }
+
+  /**
+   * See {@link ControllerClient#getTaskList()} for the client-side code that calls this API.
+   */
+  @GET
+  @Path("/taskList")
+  @Produces(MediaType.APPLICATION_JSON)
+  public Response httpGetTaskList(@Context final HttpServletRequest req)
+  {
+    MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req);
+    return Response.ok(new MSQTaskList(controller.getTaskIds())).build();
+  }
+
+  /**
+   * See {@link org.apache.druid.indexing.overlord.RemoteTaskRunner#streamTaskReports} for the client-side code that
+   * calls this API.
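+   *
+   * For example, {@code GET /liveReports} returns the JSON-serialized {@link TaskReport.ReportMap} while the
+   * query runs, or HTTP 404 (NOT FOUND) when {@link Controller#liveReports()} returns null.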
+   */
+  @GET
+  @Path("/liveReports")
+  @Produces(MediaType.APPLICATION_JSON)
+  public Response httpGetLiveReports(@Context final HttpServletRequest req)
+  {
+    MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req);
+    final TaskReport.ReportMap reports = controller.liveReports();
+    if (reports == null) {
+      return Response.status(Response.Status.NOT_FOUND).build();
+    }
+    return Response.ok(reports).build();
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java
new file mode 100644
index 000000000000..30a8179fe0f0
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.rpc;
+
+import org.apache.druid.server.security.Access;
+import org.apache.druid.server.security.AuthorizationUtils;
+import org.apache.druid.server.security.AuthorizerMapper;
+import org.apache.druid.server.security.ForbiddenException;
+import org.apache.druid.server.security.ResourceAction;
+
+import javax.servlet.http.HttpServletRequest;
+import java.util.List;
+
+/**
+ * Utility methods for MSQ resources such as {@link ControllerResource}.
+ */
+public class MSQResourceUtils
+{
+  public static void authorizeAdminRequest(
+      final ResourcePermissionMapper permissionMapper,
+      final AuthorizerMapper authorizerMapper,
+      final HttpServletRequest request
+  )
+  {
+    final List<ResourceAction> resourceActions = permissionMapper.getAdminPermissions();
+
+    Access access = AuthorizationUtils.authorizeAllResourceActions(request, resourceActions, authorizerMapper);
+
+    if (!access.isAllowed()) {
+      throw new ForbiddenException(access.toString());
+    }
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/guice/annotations/MSQ.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java
similarity index 59%
rename from extensions-core/multi-stage-query/src/main/java/org/apache/druid/guice/annotations/MSQ.java
rename to extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java
index c480168de258..8c79f4fa0e05 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/guice/annotations/MSQ.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java
@@ -17,24 +17,17 @@
  * under the License.
 */

-package org.apache.druid.guice.annotations;
+package org.apache.druid.msq.rpc;

-import com.google.inject.BindingAnnotation;
+import org.apache.druid.server.security.ResourceAction;

-import java.lang.annotation.ElementType;
-import java.lang.annotation.Retention;
-import java.lang.annotation.RetentionPolicy;
-import java.lang.annotation.Target;
+import java.util.List;

 /**
- * Binding annotation for implements of interfaces that are MSQ (MultiStageQuery) focused. This is generally
- * contrasted with the NativeQ annotation.
- *
- * @see Parent
+ * Provides HTTP resources such as {@link ControllerResource} with information about which permissions are needed
+ * for requests.
 */
-@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.METHOD})
-@Retention(RetentionPolicy.RUNTIME)
-@BindingAnnotation
-public @interface MSQ
+public interface ResourcePermissionMapper
 {
+  List<ResourceAction> getAdminPermissions();
 }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/DurableStorageInputChannelFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/DurableStorageInputChannelFactory.java
index c0e892b99bf0..f913dbb1858a 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/DurableStorageInputChannelFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/DurableStorageInputChannelFactory.java
@@ -25,6 +25,7 @@
 import org.apache.druid.frame.channel.ReadableInputStreamFrameChannel;
 import org.apache.druid.java.util.common.IOE;
 import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.java.util.common.concurrent.Execs;
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.java.util.common.logger.Logger;
@@ -73,8 +74,11 @@ public static DurableStorageInputChannelFactory createStandardImplementation(
       final boolean isQueryResults
   )
   {
+    final String threadNameFormat =
+        StringUtils.encodeForFormat(Preconditions.checkNotNull(controllerTaskId, "controllerTaskId"))
+        + "-remote-fetcher-%d";
     final ExecutorService remoteInputStreamPool =
-        Executors.newCachedThreadPool(Execs.makeThreadFactory(controllerTaskId + "-remote-fetcher-%d"));
+        Executors.newCachedThreadPool(Execs.makeThreadFactory(threadNameFormat));
     closer.register(remoteInputStreamPool::shutdownNow);
     if (isQueryResults) {
       return new DurableStorageQueryResultsInputChannelFactory(
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java
index cc360a48ede2..4beb2a869ef0 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java
@@ -37,7 +37,6 @@
 import org.apache.druid.error.NotFound;
 import org.apache.druid.error.QueryExceptionCompat;
 import org.apache.druid.frame.channel.FrameChannelSequence;
-import org.apache.druid.guice.annotations.MSQ;
 import org.apache.druid.indexer.TaskStatusPlus;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.RE;
@@ -48,6 +47,7 @@
 import org.apache.druid.java.util.common.guava.Yielders;
 import org.apache.druid.java.util.common.io.Closer;
 import 
org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.exec.ResultsContext; import org.apache.druid.msq.guice.MultiStageQuery; import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.indexing.MSQSpec; @@ -131,7 +131,7 @@ public class SqlStatementResource @Inject public SqlStatementResource( - final @MSQ SqlStatementFactory msqSqlStatementFactory, + final @MultiStageQuery SqlStatementFactory msqSqlStatementFactory, final ObjectMapper jsonMapper, final OverlordClient overlordClient, final @MultiStageQuery StorageConnector storageConnector, @@ -540,27 +540,9 @@ private Optional getResultSetInformation( List results = null; if (isSelectQuery) { results = new ArrayList<>(); - Yielder yielder = null; if (msqTaskReportPayload.getResults() != null) { - yielder = msqTaskReportPayload.getResults().getResultYielder(); + results = msqTaskReportPayload.getResults().getResults(); } - try { - while (yielder != null && !yielder.isDone()) { - results.add(yielder.get()); - yielder = yielder.next(null); - } - } - finally { - if (yielder != null) { - try { - yielder.close(); - } - catch (IOException e) { - log.warn(e, StringUtils.format("Unable to close yielder for query[%s]", queryId)); - } - } - } - } return Optional.of( @@ -739,10 +721,10 @@ private Optional> getResultYielder( contactOverlord(overlordClient.taskReportAsMap(queryId), queryId) ); - if (msqTaskReportPayload.getResults().getResultYielder() == null) { + if (msqTaskReportPayload.getResults().getResults() == null) { results = Optional.empty(); } else { - results = Optional.of(msqTaskReportPayload.getResults().getResultYielder()); + results = Optional.of(Yielders.each(Sequences.simple(msqTaskReportPayload.getResults().getResults()))); } } else if (msqControllerTask.getQuerySpec().getDestination() instanceof DurableStorageMSQDestination) { @@ -801,12 +783,17 @@ private Optional> getResultYielder( } }) .collect(Collectors.toList())) - .flatMap(frame -> SqlStatementResourceHelper.getResultSequence( - msqControllerTask, - finalStage, - frame, - jsonMapper - ) + .flatMap(frame -> + SqlStatementResourceHelper.getResultSequence( + frame, + finalStage.getFrameReader(), + msqControllerTask.getQuerySpec().getColumnMappings(), + new ResultsContext( + msqControllerTask.getSqlTypeNames(), + msqControllerTask.getSqlResultsContext() + ), + jsonMapper + ) ) .withBaggage(closer))); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java index e9b4c61cef23..7a51bc8d26a4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java @@ -26,12 +26,12 @@ import org.apache.druid.error.DruidException; import org.apache.druid.error.ErrorResponse; import org.apache.druid.error.QueryExceptionCompat; -import org.apache.druid.guice.annotations.MSQ; import org.apache.druid.indexer.TaskState; import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.guice.MultiStageQuery; import org.apache.druid.msq.sql.MSQTaskSqlEngine; import org.apache.druid.msq.sql.SqlTaskStatus; import 
org.apache.druid.query.QueryException; @@ -86,7 +86,7 @@ public class SqlTaskResource @Inject public SqlTaskResource( - final @MSQ SqlStatementFactory sqlStatementFactory, + final @MultiStageQuery SqlStatementFactory sqlStatementFactory, final ServerConfig serverConfig, final AuthorizerMapper authorizerMapper, final ObjectMapper jsonMapper diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/PartialKeyStatisticsInformation.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/PartialKeyStatisticsInformation.java index 535af8dafb0a..9a6e256d9add 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/PartialKeyStatisticsInformation.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/PartialKeyStatisticsInformation.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Objects; import java.util.Set; /** @@ -64,4 +65,35 @@ public double getBytesRetained() { return bytesRetained; } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PartialKeyStatisticsInformation that = (PartialKeyStatisticsInformation) o; + return multipleValues == that.multipleValues + && Double.compare(bytesRetained, that.bytesRetained) == 0 + && Objects.equals(timeSegments, that.timeSegments); + } + + @Override + public int hashCode() + { + return Objects.hash(timeSegments, multipleValues, bytesRetained); + } + + @Override + public String toString() + { + return "PartialKeyStatisticsInformation{" + + "timeSegments=" + timeSegments + + ", multipleValues=" + multipleValues + + ", bytesRetained=" + bytesRetained + + '}'; + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java index 60734b5b1dad..4b599cd32d5b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java @@ -33,6 +33,7 @@ import org.apache.druid.msq.exec.Limits; import org.apache.druid.msq.exec.SegmentSource; import org.apache.druid.msq.indexing.destination.MSQSelectDestination; +import org.apache.druid.msq.indexing.error.MSQWarnings; import org.apache.druid.msq.kernel.WorkerAssignmentStrategy; import org.apache.druid.msq.sql.MSQMode; import org.apache.druid.query.QueryContext; @@ -112,6 +113,8 @@ public class MultiStageQueryContext public static final String CTX_INCLUDE_SEGMENT_SOURCE = "includeSegmentSource"; public static final SegmentSource DEFAULT_INCLUDE_SEGMENT_SOURCE = SegmentSource.NONE; + public static final String CTX_MAX_CONCURRENT_STAGES = "maxConcurrentStages"; + public static final int DEFAULT_MAX_CONCURRENT_STAGES = 1; public static final String CTX_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage"; private static final boolean DEFAULT_DURABLE_SHUFFLE_STORAGE = false; public static final String CTX_SELECT_DESTINATION = "selectDestination"; @@ -173,6 +176,14 @@ public static String getMSQMode(final QueryContext queryContext) ); } + public static int getMaxConcurrentStages(final QueryContext queryContext) + { + return queryContext.getInt( + CTX_MAX_CONCURRENT_STAGES, + 
DEFAULT_MAX_CONCURRENT_STAGES + ); + } + public static boolean isDurableStorageEnabled(final QueryContext queryContext) { return queryContext.getBoolean( @@ -316,6 +327,14 @@ public static IndexSpec getIndexSpec(final QueryContext queryContext, final Obje return decodeIndexSpec(queryContext.get(CTX_INDEX_SPEC), objectMapper); } + public static long getMaxParseExceptions(final QueryContext queryContext) + { + return queryContext.getLong( + MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, + MSQWarnings.DEFAULT_MAX_PARSE_EXCEPTIONS_ALLOWED + ); + } + public static boolean useAutoColumnSchemas(final QueryContext queryContext) { return queryContext.getBoolean(CTX_USE_AUTO_SCHEMAS, DEFAULT_USE_AUTO_SCHEMAS); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java index f90959a56667..4f07dcb2cc02 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java @@ -29,6 +29,7 @@ import org.apache.druid.error.NotFound; import org.apache.druid.frame.Frame; import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.read.FrameReader; import org.apache.druid.indexer.TaskLocation; import org.apache.druid.indexer.TaskState; import org.apache.druid.indexer.TaskStatusPlus; @@ -42,6 +43,7 @@ import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.counters.QueryCounterSnapshot; import org.apache.druid.msq.counters.SegmentGenerationProgressCounter; +import org.apache.druid.msq.exec.ResultsContext; import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination; import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; @@ -52,7 +54,6 @@ import org.apache.druid.msq.indexing.report.MSQStagesReport; import org.apache.druid.msq.indexing.report.MSQTaskReport; import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; -import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.sql.SqlStatementState; import org.apache.druid.msq.sql.entity.ColumnNameAndTypes; import org.apache.druid.msq.sql.entity.PageInformation; @@ -71,6 +72,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.stream.Collectors; @@ -296,25 +298,23 @@ protected DruidException makeException(DruidException.DruidExceptionBuilder bob) } public static Sequence getResultSequence( - MSQControllerTask msqControllerTask, - StageDefinition finalStage, - Frame frame, - ObjectMapper jsonMapper + final Frame resultsFrame, + final FrameReader resultFrameReader, + final ColumnMappings resultColumnMappings, + final ResultsContext resultsContext, + final ObjectMapper jsonMapper ) { - final Cursor cursor = FrameProcessors.makeCursor(frame, finalStage.getFrameReader()); - + final Cursor cursor = FrameProcessors.makeCursor(resultsFrame, resultFrameReader); final ColumnSelectorFactory columnSelectorFactory = cursor.getColumnSelectorFactory(); - final ColumnMappings columnMappings = msqControllerTask.getQuerySpec().getColumnMappings(); @SuppressWarnings("rawtypes") - final List selectors = columnMappings.getMappings() - .stream() - .map(mapping -> 
columnSelectorFactory.makeColumnValueSelector( - mapping.getQueryColumn())) - .collect(Collectors.toList()); - - final List sqlTypeNames = msqControllerTask.getSqlTypeNames(); - Iterable retVal = () -> new Iterator() + final List selectors = + resultColumnMappings.getMappings() + .stream() + .map(mapping -> columnSelectorFactory.makeColumnValueSelector(mapping.getQueryColumn())) + .collect(Collectors.toList()); + + final Iterable retVal = () -> new Iterator() { @Override public boolean hasNext() @@ -325,19 +325,23 @@ public boolean hasNext() @Override public Object[] next() { - final Object[] row = new Object[columnMappings.size()]; + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final Object[] row = new Object[resultColumnMappings.size()]; for (int i = 0; i < row.length; i++) { final Object value = selectors.get(i).getObject(); - if (sqlTypeNames == null || msqControllerTask.getSqlResultsContext() == null) { + if (resultsContext.getSqlTypeNames() == null || resultsContext.getSqlResultsContext() == null) { // SQL type unknown, or no SQL results context: pass-through as is. row[i] = value; } else { row[i] = SqlResults.coerce( jsonMapper, - msqControllerTask.getSqlResultsContext(), + resultsContext.getSqlResultsContext(), value, - sqlTypeNames.get(i), - columnMappings.getOutputColumnName(i) + resultsContext.getSqlTypeNames().get(i), + resultColumnMappings.getOutputColumnName(i) ); } } diff --git a/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule b/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule index cabd131fb758..92be5604cb8a 100644 --- a/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule +++ b/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +org.apache.druid.msq.guice.IndexerMemoryManagementModule +org.apache.druid.msq.guice.MSQDurableStorageModule org.apache.druid.msq.guice.MSQExternalDataSourceModule org.apache.druid.msq.guice.MSQIndexingModule -org.apache.druid.msq.guice.MSQDurableStorageModule org.apache.druid.msq.guice.MSQSqlModule +org.apache.druid.msq.guice.PeonMemoryManagementModule org.apache.druid.msq.guice.SqlTaskModule diff --git a/extensions-core/multi-stage-query/src/main/resources/log4j2.xml b/extensions-core/multi-stage-query/src/main/resources/log4j2.xml index e99abd743366..d98bb05ef6cd 100644 --- a/extensions-core/multi-stage-query/src/main/resources/log4j2.xml +++ b/extensions-core/multi-stage-query/src/main/resources/log4j2.xml @@ -31,6 +31,9 @@ + + + diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerImplTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerImplTest.java index 41c3cff66a50..db5ef1a089c2 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerImplTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerImplTest.java @@ -72,6 +72,7 @@ public void test_performSegmentPublish_ok() throws IOException // All OK. 
ControllerImpl.performSegmentPublish(taskActionClient, action); + EasyMock.verify(taskActionClient); } @Test @@ -90,6 +91,7 @@ public void test_performSegmentPublish_publishFail() throws IOException ); Assert.assertEquals(InsertLockPreemptedFault.instance(), e.getFault()); + EasyMock.verify(taskActionClient); } @Test @@ -108,6 +110,7 @@ public void test_performSegmentPublish_publishException() throws IOException ); Assert.assertEquals("oops", e.getMessage()); + EasyMock.verify(taskActionClient); } @Test @@ -126,6 +129,7 @@ public void test_performSegmentPublish_publishLockPreemptedException() throws IO ); Assert.assertEquals(InsertLockPreemptedFault.instance(), e.getFault()); + EasyMock.verify(taskActionClient); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerMemoryParametersTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerMemoryParametersTest.java new file mode 100644 index 000000000000..9d27dcca666b --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerMemoryParametersTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.exec; + +import com.google.common.collect.ImmutableMap; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.NotEnoughMemoryFault; +import org.apache.druid.sql.calcite.util.TestLookupProvider; +import org.junit.Assert; +import org.junit.Test; + +public class ControllerMemoryParametersTest +{ + private static final double USABLE_MEMORY_FRACTION = 0.8; + private static final int NUM_PROCESSORS_IN_JVM = 2; + + @Test + public void test_oneQueryInJvm() + { + final ControllerMemoryParameters memoryParameters = ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(128_000_000, 1), + 1 + ); + + Assert.assertEquals(100_400_000, memoryParameters.getPartitionStatisticsMaxRetainedBytes()); + } + + @Test + public void test_oneQueryInJvm_oneHundredWorkers() + { + final ControllerMemoryParameters memoryParameters = ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(256_000_000, 1), + 100 + ); + + Assert.assertEquals(103_800_000, memoryParameters.getPartitionStatisticsMaxRetainedBytes()); + } + + @Test + public void test_twoQueriesInJvm() + { + final ControllerMemoryParameters memoryParameters = ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(128_000_000, 2), + 1 + ); + + Assert.assertEquals(49_200_000, memoryParameters.getPartitionStatisticsMaxRetainedBytes()); + } + + @Test + public void test_maxSized() + { + final ControllerMemoryParameters memoryParameters = ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(1_000_000_000, 1), + 1 + ); + + Assert.assertEquals(300_000_000, memoryParameters.getPartitionStatisticsMaxRetainedBytes()); + } + + @Test + public void test_notEnoughMemory() + { + final MSQException e = Assert.assertThrows( + MSQException.class, + () -> ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(30_000_000, 1), + 1 + ) + ); + + final NotEnoughMemoryFault fault = (NotEnoughMemoryFault) e.getFault(); + Assert.assertEquals(30_000_000, fault.getServerMemory()); + Assert.assertEquals(1, fault.getServerWorkers()); + Assert.assertEquals(NUM_PROCESSORS_IN_JVM, fault.getServerThreads()); + Assert.assertEquals(24_000_000, fault.getUsableMemory()); + Assert.assertEquals(33_750_000, fault.getSuggestedServerMemory()); + } + + @Test + public void test_minimalMemory() + { + final ControllerMemoryParameters memoryParameters = ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(33_750_000, 1), + 1 + ); + + Assert.assertEquals(25_000_000, memoryParameters.getPartitionStatisticsMaxRetainedBytes()); + } + + private MemoryIntrospector makeMemoryIntrospector( + final long totalMemoryInJvm, + final int numQueriesInJvm + ) + { + return new MemoryIntrospectorImpl( + new TestLookupProvider(ImmutableMap.of()), + totalMemoryInJvm, + USABLE_MEMORY_FRACTION, + numQueriesInJvm, + NUM_PROCESSORS_IN_JVM + ); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java index b3b1442074b7..425609628b3a 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java @@ -27,6 +27,7 @@ import org.apache.druid.indexing.common.TaskLockType; import 
org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction; import org.apache.druid.indexing.common.actions.SegmentAllocateAction; +import org.apache.druid.indexing.common.actions.TaskAction; import org.apache.druid.indexing.common.task.Tasks; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; @@ -51,6 +52,7 @@ import org.hamcrest.CoreMatchers; import org.junit.internal.matchers.ThrowableMessageMatcher; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatcher; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; @@ -532,7 +534,10 @@ public void testReplaceTombstonesWithTooManyBucketsThrowsFault() Mockito.doReturn(ImmutableSet.of(existingDataSegment)) .when(testTaskActionClient) - .submit(ArgumentMatchers.isA(RetrieveUsedSegmentsAction.class)); + .submit(ArgumentMatchers.argThat( + (ArgumentMatcher>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()))); String expectedError = new TooManyBucketsFault(Limits.MAX_PARTITION_BUCKETS).getErrorMessage(); @@ -578,7 +583,10 @@ public void testReplaceTombstonesWithTooManyBucketsThrowsFault2() Mockito.doReturn(ImmutableSet.of(existingDataSegment)) .when(testTaskActionClient) - .submit(ArgumentMatchers.isA(RetrieveUsedSegmentsAction.class)); + .submit(ArgumentMatchers.argThat( + (ArgumentMatcher>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()))); String expectedError = new TooManyBucketsFault(Limits.MAX_PARTITION_BUCKETS).getErrorMessage(); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java index f05e35c304c0..ecdc30294dbd 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java @@ -1343,7 +1343,7 @@ public void testInsertWithTooLargeRowShouldThrowException(String contextName, Ma final File toRead = getResourceAsTemporaryFile("/wikipedia-sampled.json"); final String toReadFileNameAsJson = queryFramework().queryJsonMapper().writeValueAsString(toRead.getAbsolutePath()); - Mockito.doReturn(500).when(workerMemoryParameters).getLargeFrameSize(); + Mockito.doReturn(500).when(workerMemoryParameters).getStandardFrameSize(); testIngestQuery().setSql(" insert into foo1 SELECT\n" + " floor(TIME_PARSE(\"timestamp\") to day) AS __time,\n" diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java index 9a4fb98666b3..227a9656a142 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java @@ -34,6 +34,7 @@ import org.apache.druid.indexer.partitions.PartitionsSpec; import org.apache.druid.indexing.common.TaskLockType; import org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction; +import org.apache.druid.indexing.common.actions.TaskAction; import org.apache.druid.indexing.common.task.Tasks; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Intervals; @@ -58,6 +59,7 @@ import 
org.joda.time.Interval; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentMatcher; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; @@ -1650,7 +1652,12 @@ public void testEmptyReplaceAllOverEternitySegment(String contextName, Map>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()) + )); // Insert with a condition which results in 0 rows being inserted -- do nothing. testIngestQuery().setSql( @@ -1683,7 +1690,10 @@ public void testEmptyReplaceAllWithAllGrainOverFiniteIntervalSegment(String cont .build(); Mockito.doReturn(ImmutableSet.of(existingDataSegment)) .when(testTaskActionClient) - .submit(ArgumentMatchers.isA(RetrieveUsedSegmentsAction.class)); + .submit(ArgumentMatchers.argThat( + (ArgumentMatcher>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()))); // Insert with a condition which results in 0 rows being inserted -- do nothing. testIngestQuery().setSql( @@ -1716,7 +1726,10 @@ public void testEmptyReplaceAllWithAllGrainOverEternitySegment(String contextNam Mockito.doReturn(ImmutableSet.of(existingDataSegment)) .when(testTaskActionClient) - .submit(ArgumentMatchers.isA(RetrieveUsedSegmentsAction.class)); + .submit(ArgumentMatchers.argThat( + (ArgumentMatcher>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()))); // Insert with a condition which results in 0 rows being inserted -- do nothing. testIngestQuery().setSql( @@ -1800,7 +1813,10 @@ public void testEmptyReplaceIntervalOverEternitySegment(String contextName, Map< Mockito.doReturn(ImmutableSet.of(existingDataSegment)) .when(testTaskActionClient) - .submit(ArgumentMatchers.isA(RetrieveUsedSegmentsAction.class)); + .submit(ArgumentMatchers.argThat( + (ArgumentMatcher>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()))); // Insert with a condition which results in 0 rows being inserted -- do nothing! testIngestQuery().setSql( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java index 7c4af7389f6c..56f1ce986965 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java @@ -2220,10 +2220,7 @@ public void testSelectRowsGetUntruncatedByDefault(String contextName, Map context) { - - - - // This test asserts that the join algorithnm used is a different one from that supplied. In sqlCompatible() mode + // This test asserts that the join algorithm used is a different one from that supplied. In sqlCompatible() mode // the query gets planned differently, therefore we do use the sortMerge processor. 
Instead of having separate
     // handling, a similar test has been described in CalciteJoinQueryMSQTest, therefore we don't want to repeat that
     // here, hence ignoring in sqlCompatible() mode
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java
index 73a443db8a25..904510408ff0 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java
@@ -20,10 +20,19 @@
 package org.apache.druid.msq.exec;

 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
 import com.google.errorprone.annotations.concurrent.GuardedBy;
+import org.apache.druid.client.indexing.NoopOverlordClient;
+import org.apache.druid.client.indexing.TaskStatusResponse;
+import org.apache.druid.common.guava.FutureUtils;
+import org.apache.druid.indexer.RunnerTaskState;
 import org.apache.druid.indexer.TaskLocation;
 import org.apache.druid.indexer.TaskState;
 import org.apache.druid.indexer.TaskStatus;
+import org.apache.druid.indexer.TaskStatusPlus;
+import org.apache.druid.java.util.common.DateTimes;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.msq.indexing.MSQWorkerTask;
 import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher;
@@ -37,6 +46,7 @@
 import org.apache.druid.msq.indexing.error.TooManyWorkersFault;
 import org.apache.druid.msq.indexing.error.UnknownFault;
 import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault;
+import org.apache.druid.utils.CollectionUtils;
 import org.junit.Assert;
 import org.junit.Test;

@@ -47,8 +57,6 @@
 import java.util.concurrent.TimeUnit;

 import static org.junit.Assert.fail;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;

 public class MSQTasksTest
 {
@@ -214,12 +222,10 @@ public void test_queryWithoutEnoughSlots_shouldThrowException()
     final int numSlots = 5;
     final int numTasks = 10;

-    ControllerContext controllerContext = mock(ControllerContext.class);
-    when(controllerContext.workerManager()).thenReturn(new TasksTestWorkerManagerClient(numSlots));
     MSQWorkerTaskLauncher msqWorkerTaskLauncher = new MSQWorkerTaskLauncher(
         CONTROLLER_ID,
         "foo",
-        controllerContext,
+        new TasksTestOverlordClient(numSlots),
         (task, fault) -> {},
         ImmutableMap.of(),
         TimeUnit.SECONDS.toMillis(5)
@@ -227,7 +233,7 @@

     try {
       msqWorkerTaskLauncher.start();
-      msqWorkerTaskLauncher.launchTasksIfNeeded(numTasks);
+      msqWorkerTaskLauncher.launchWorkersIfNeeded(numTasks);
       fail();
     }
     catch (Exception e) {
@@ -238,7 +244,7 @@
     }
   }

-  static class TasksTestWorkerManagerClient implements WorkerManagerClient
+  static class TasksTestOverlordClient extends NoopOverlordClient
   {
     // Num of slots available for tasks
     final int numSlots;
@@ -252,13 +258,13 @@ static class TasksTestWorkerManagerClient implements WorkerManagerClient
     @GuardedBy("this")
     final Set<String> canceledTasks = new HashSet<>();

-    public TasksTestWorkerManagerClient(final int numSlots)
+    public TasksTestOverlordClient(final int numSlots)
     {
       this.numSlots = numSlots;
     }

     @Override
-    public synchronized Map<String, TaskStatus> statuses(final Set<String> taskIds)
+    public synchronized ListenableFuture<Map<String, TaskStatus>> taskStatuses(final Set<String> taskIds)
     {
       final Map<String, TaskStatus> retVal = new HashMap<>();
@@ -277,42 +283,66 @@
       }
     }

-      return retVal;
+      return Futures.immediateFuture(retVal);
     }

     @Override
-    public synchronized TaskLocation location(String workerId)
+    public synchronized ListenableFuture<TaskStatusResponse> taskStatus(String workerId)
     {
+      final TaskStatus status = CollectionUtils.getOnlyElement(
+          FutureUtils.getUnchecked(taskStatuses(ImmutableSet.of(workerId)), true).values(),
+          xs -> new ISE("Expected one worker with id[%s] but saw[%s]", workerId, xs)
+      );
+
+      final TaskLocation location;
+
       if (runningTasks.contains(workerId)) {
-        return TaskLocation.create("host-" + workerId, 1, -1);
+        location = TaskLocation.create("host-" + workerId, 1, -1);
       } else {
-        return TaskLocation.unknown();
+        location = TaskLocation.unknown();
       }
+
+      return Futures.immediateFuture(
+          new TaskStatusResponse(
+              status.getId(),
+              new TaskStatusPlus(
+                  status.getId(),
+                  null,
+                  null,
+                  DateTimes.utc(0),
+                  DateTimes.utc(0),
+                  status.getStatusCode(),
+                  status.getStatusCode(),
+                  RunnerTaskState.NONE,
+                  status.getDuration(),
+                  location,
+                  null,
+                  status.getErrorMsg()
+              )
+          )
+      );
     }

     @Override
-    public synchronized String run(String taskId, MSQWorkerTask task)
+    public synchronized ListenableFuture<Void> runTask(String taskId, Object taskObject)
     {
+      final MSQWorkerTask task = (MSQWorkerTask) taskObject;
+
       allTasks.add(task.getId());

       if (runningTasks.size() < numSlots) {
         runningTasks.add(task.getId());
       }

-      return task.getId();
+      return Futures.immediateFuture(null);
     }

     @Override
-    public synchronized void cancel(String workerId)
+    public synchronized ListenableFuture<Void> cancelTask(String workerId)
     {
       runningTasks.remove(workerId);
       canceledTasks.add(workerId);
-    }
-
-    @Override
-    public void close()
-    {
-      // do nothing
+      return Futures.immediateFuture(null);
     }
   }
 }
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/QueryValidatorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/QueryValidatorTest.java
index d9cbb48d986c..d7364124483a 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/QueryValidatorTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/QueryValidatorTest.java
@@ -106,6 +106,8 @@ public void testMoreInputFiles()
         0,
         0,
         Collections.singletonList(() -> inputFiles), // Slice with a large number of inputFiles
+        null,
+        null,
         null
     );

@@ -125,8 +127,7 @@

   private static QueryDefinition createQueryDefinition(int numColumns, int numWorkers)
   {
-    QueryDefinitionBuilder builder = QueryDefinition.builder();
-    builder.queryId(UUID.randomUUID().toString());
+    QueryDefinitionBuilder builder = QueryDefinition.builder(UUID.randomUUID().toString());

     StageDefinitionBuilder stageBuilder = StageDefinition.builder(0);
     builder.add(stageBuilder);
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java
index 592fd089ef4e..cba8ede156ce 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java
@@ -24,7 +24,6 @@
 import com.google.common.collect.ImmutableSortedMap;
 import com.google.common.util.concurrent.Futures;
 import 
org.apache.druid.java.util.common.ISE; -import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.controller.ControllerQueryKernel; @@ -44,7 +43,6 @@ import static org.easymock.EasyMock.mock; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; @@ -56,7 +54,7 @@ public class WorkerSketchFetcherTest private CompleteKeyStatisticsInformation completeKeyStatisticsInformation; @Mock - private MSQWorkerTaskLauncher workerTaskLauncher; + private WorkerManager workerManager; @Mock private ControllerQueryKernel kernel; @@ -82,7 +80,10 @@ public void setUp() doReturn(ImmutableSortedMap.of(123L, ImmutableSet.of(1, 2))).when(completeKeyStatisticsInformation) .getTimeSegmentVsWorkerMap(); - doReturn(true).when(workerTaskLauncher).isTaskLatest(any()); + doReturn(0).when(workerManager).getWorkerNumber(TASK_0); + doReturn(1).when(workerManager).getWorkerNumber(TASK_1); + doReturn(2).when(workerManager).getWorkerNumber(TASK_2); + doReturn(true).when(workerManager).isWorkerActive(any()); } @After @@ -100,13 +101,13 @@ public void test_submitFetcherTask_parallelFetch() throws InterruptedException final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); // When fetching snapshots, return a mock and add it to queue doAnswer(invocation -> { ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class); return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt()); + }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any()); target.inMemoryFullSketchMerging((kernelConsumer) -> { kernelConsumer.accept(kernel); @@ -123,13 +124,13 @@ public void test_submitFetcherTask_sequentialFetch() throws InterruptedException doReturn(true).when(completeKeyStatisticsInformation).isComplete(); final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); // When fetching snapshots, return a mock and add it to queue doAnswer(invocation -> { ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class); return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyInt(), anyLong()); + }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong()); target.sequentialTimeChunkMerging( (kernelConsumer) -> { @@ -151,7 +152,7 @@ public void test_sequentialMerge_nonCompleteInformation() { doReturn(false).when(completeKeyStatisticsInformation).isComplete(); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); Assert.assertThrows(ISE.class, () -> target.sequentialTimeChunkMerging( (ignore) -> {}, completeKeyStatisticsInformation, @@ -166,7 +167,7 @@ public void test_inMemoryRetryEnabled_retryInvoked() throws InterruptedException { final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = 
spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); workersWithFailedFetchParallel(ImmutableSet.of(TASK_1)); @@ -185,8 +186,8 @@ public void test_inMemoryRetryEnabled_retryInvoked() throws InterruptedException }) ); - Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); - Assert.assertTrue(retryLatch.await(5, TimeUnit.SECONDS)); + Assert.assertTrue(latch.await(500, TimeUnit.SECONDS)); + Assert.assertTrue(retryLatch.await(500, TimeUnit.SECONDS)); } @Test @@ -195,7 +196,7 @@ public void test_SequentialRetryEnabled_retryInvoked() throws InterruptedExcepti doReturn(true).when(completeKeyStatisticsInformation).isComplete(); final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); workersWithFailedFetchSequential(ImmutableSet.of(TASK_1)); CountDownLatch retryLatch = new CountDownLatch(1); @@ -222,7 +223,7 @@ public void test_SequentialRetryEnabled_retryInvoked() throws InterruptedExcepti public void test_InMemoryRetryDisabled_multipleFailures() throws InterruptedException { - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); workersWithFailedFetchParallel(ImmutableSet.of(TASK_1, TASK_0)); @@ -251,7 +252,7 @@ public void test_InMemoryRetryDisabled_multipleFailures() throws InterruptedExce public void test_InMemoryRetryDisabled_singleFailure() throws InterruptedException { - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); workersWithFailedFetchParallel(ImmutableSet.of(TASK_1)); @@ -282,7 +283,7 @@ public void test_SequentialRetryDisabled_multipleFailures() throws InterruptedEx { doReturn(true).when(completeKeyStatisticsInformation).isComplete(); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); workersWithFailedFetchSequential(ImmutableSet.of(TASK_1, TASK_0)); @@ -314,7 +315,7 @@ public void test_SequentialRetryDisabled_multipleFailures() throws InterruptedEx public void test_SequentialRetryDisabled_singleFailure() throws InterruptedException { doReturn(true).when(completeKeyStatisticsInformation).isComplete(); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); workersWithFailedFetchSequential(ImmutableSet.of(TASK_1)); @@ -351,7 +352,7 @@ private void workersWithFailedFetchSequential(Set failedTasks) return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0))); } return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyInt(), anyLong()); + }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong()); } private void workersWithFailedFetchParallel(Set failedTasks) @@ -362,7 +363,7 @@ private void workersWithFailedFetchParallel(Set failedTasks) return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0))); } return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt()); + 
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any()); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java index bc3f24065aea..25ab33f76f94 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java @@ -21,7 +21,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.rpc.indexing.OverlordClient; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -40,7 +40,7 @@ public void setUp() target = new MSQWorkerTaskLauncher( "controller-id", "foo", - Mockito.mock(ControllerContext.class), + Mockito.mock(OverlordClient.class), (task, fault) -> {}, ImmutableMap.of(), TimeUnit.SECONDS.toMillis(5) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/TaskReportQueryListenerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/TaskReportQueryListenerTest.java new file mode 100644 index 000000000000..11c33d215170 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/TaskReportQueryListenerTest.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.indexing; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.indexer.TaskState; +import org.apache.druid.indexer.report.TaskContextReport; +import org.apache.druid.indexer.report.TaskReport; +import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.guice.MSQIndexingModule; +import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; +import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.indexing.report.MSQStagesReport; +import org.apache.druid.msq.indexing.report.MSQStatusReport; +import org.apache.druid.msq.indexing.report.MSQTaskReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; +import org.apache.druid.msq.indexing.report.MSQTaskReportTest; +import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.ColumnType; +import org.junit.Assert; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class TaskReportQueryListenerTest +{ + private static final String TASK_ID = "mytask"; + private static final Map<String, Object> TASK_CONTEXT = ImmutableMap.of("foo", "bar"); + private static final List<MSQResultsReport.ColumnAndType> SIGNATURE = ImmutableList.of( + new MSQResultsReport.ColumnAndType("x", ColumnType.STRING) + ); + private static final List<SqlTypeName> SQL_TYPE_NAMES = ImmutableList.of(SqlTypeName.VARCHAR); + private static final ObjectMapper JSON_MAPPER = + TestHelper.makeJsonMapper().registerModules(new MSQIndexingModule().getJacksonModules()); + + private final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + @Test + public void test_taskReportDestination() throws IOException + { + final TaskReportQueryListener listener = new TaskReportQueryListener( + TaskReportMSQDestination.instance(), + Suppliers.ofInstance(baos)::get, + JSON_MAPPER, + TASK_ID, + TASK_CONTEXT + ); + + Assert.assertTrue(listener.readResults()); + listener.onResultsStart(SIGNATURE, SQL_TYPE_NAMES); + Assert.assertTrue(listener.onResultRow(new Object[]{"foo"})); + Assert.assertTrue(listener.onResultRow(new Object[]{"bar"})); + listener.onResultsComplete(); + listener.onQueryComplete( + new MSQTaskReportPayload( + new MSQStatusReport( + TaskState.SUCCESS, + null, + Collections.emptyList(), + null, + 0, + new HashMap<>(), + 1, + 2, + null, + null + ), + MSQStagesReport.create( + MSQTaskReportTest.QUERY_DEFINITION, + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of() + ), + new CounterSnapshotsTree(), + null + ) + ); + + final TaskReport.ReportMap reportMap = + JSON_MAPPER.readValue( + baos.toByteArray(), + new TypeReference<TaskReport.ReportMap>() {} + ); + + Assert.assertEquals(ImmutableSet.of("multiStageQuery", TaskContextReport.REPORT_KEY), reportMap.keySet()); + Assert.assertEquals(TASK_CONTEXT, ((TaskContextReport) reportMap.get(TaskContextReport.REPORT_KEY)).getPayload()); + + final MSQTaskReport report = (MSQTaskReport) reportMap.get("multiStageQuery"); + final List<List<Object>> results = + report.getPayload().getResults().getResults().stream().map(Arrays::asList).collect(Collectors.toList()); + + Assert.assertEquals( + ImmutableList.of( + ImmutableList.of("foo"), + ImmutableList.of("bar") + ), + results + ); + + Assert.assertFalse(report.getPayload().getResults().isResultsTruncated()); + Assert.assertEquals(TaskState.SUCCESS, report.getPayload().getStatus().getStatus()); + } + + @Test + public void test_durableDestination() throws IOException + { + final TaskReportQueryListener listener = new TaskReportQueryListener( + DurableStorageMSQDestination.instance(), + Suppliers.ofInstance(baos)::get, + JSON_MAPPER, + TASK_ID, + TASK_CONTEXT + ); + + Assert.assertTrue(listener.readResults()); + listener.onResultsStart(SIGNATURE, SQL_TYPE_NAMES); + for (int i = 0; i < Limits.MAX_SELECT_RESULT_ROWS - 1; i++) { + Assert.assertTrue("row #" + i, listener.onResultRow(new Object[]{"foo"})); + } + Assert.assertFalse(listener.onResultRow(new Object[]{"foo"})); + listener.onQueryComplete( + new MSQTaskReportPayload( + new MSQStatusReport( + TaskState.SUCCESS, + null, + Collections.emptyList(), + null, + 0, + new HashMap<>(), + 1, + 2, + null, + null + ), + MSQStagesReport.create( + MSQTaskReportTest.QUERY_DEFINITION, + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of() + ), + new CounterSnapshotsTree(), + null + ) + ); + + final TaskReport.ReportMap reportMap = + JSON_MAPPER.readValue( + baos.toByteArray(), + new TypeReference<TaskReport.ReportMap>() {} + ); + + Assert.assertEquals(ImmutableSet.of("multiStageQuery", TaskContextReport.REPORT_KEY), reportMap.keySet()); + Assert.assertEquals(TASK_CONTEXT, ((TaskContextReport) reportMap.get(TaskContextReport.REPORT_KEY)).getPayload()); + + final MSQTaskReport report = (MSQTaskReport) reportMap.get("multiStageQuery"); + final List<List<Object>> results = + report.getPayload().getResults().getResults().stream().map(Arrays::asList).collect(Collectors.toList()); + + Assert.assertEquals( + IntStream.range(0, (int) Limits.MAX_SELECT_RESULT_ROWS) + .mapToObj(i -> ImmutableList.of("foo")) + .collect(Collectors.toList()), + results + ); + + Assert.assertTrue(report.getPayload().getResults().isResultsTruncated()); + Assert.assertEquals(TaskState.SUCCESS, report.getPayload().getStatus().getStatus()); + } +}
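[Reviewer note: the two tests above pin down the QueryListener contract that TaskReportQueryListener implements: onResultRow returns true while the listener will accept more rows, and false once its row budget is exhausted. A condensed sketch of the truncation behavior the durable-storage test relies on; the counter field and the writeRowToReport step are assumptions for illustration, not the actual implementation:

    // Rows are accepted up to Limits.MAX_SELECT_RESULT_ROWS; the first false return
    // tells the caller to stop sending rows, and the report is marked truncated.
    private long rowCount = 0;

    public boolean onResultRow(final Object[] row) throws IOException
    {
      writeRowToReport(row); // hypothetical serialization into the report stream
      return ++rowCount < Limits.MAX_SELECT_RESULT_ROWS;
    }

This matches the assertions: true for the first MAX_SELECT_RESULT_ROWS - 1 rows, false on the final accepted row, exactly MAX_SELECT_RESULT_ROWS rows in the deserialized report, and isResultsTruncated() set.]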
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java index 9fe32cc8c8c1..3fd346f4db42 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java @@ -22,9 +22,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.indexer.TaskStatus; -import org.apache.druid.indexer.report.TaskReport; -import org.apache.druid.indexer.report.TaskReportFileWriter; import org.apache.druid.indexing.common.TaskToolbox; +import org.apache.druid.indexing.common.task.NoopTestTaskReportFileWriter; import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.java.util.common.ISE; import org.apache.druid.msq.counters.CounterSnapshotsTree; @@ -83,22 +82,7 @@ public void setUp() toolbox = builder.authorizerMapper(CalciteTests.TEST_AUTHORIZER_MAPPER) .indexIO(indexIO) .indexMergerV9(indexMerger) - .taskReportFileWriter( - new TaskReportFileWriter() - { - @Override - public void write(String taskId, TaskReport.ReportMap reports) - { - - } - - @Override - public void setObjectMapper(ObjectMapper objectMapper) - { - - } - } - ) + .taskReportFileWriter(new NoopTestTaskReportFileWriter()) .build(); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/ControllerChatHandlerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/ControllerChatHandlerTest.java index 10a724f4b7ed..93436c84eadd 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/ControllerChatHandlerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/ControllerChatHandlerTest.java @@ -21,9 +21,7 @@ import org.apache.druid.indexer.report.KillTaskReport; import org.apache.druid.indexer.report.TaskReport; -import org.apache.druid.indexing.common.TaskToolbox; import org.apache.druid.msq.exec.Controller; -import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.server.security.AuthorizerMapper; import org.junit.Assert; import org.junit.Test; @@ -35,6 +33,8 @@ public class ControllerChatHandlerTest { + private static final String DATASOURCE = "wiki"; + @Test public void testHttpGetLiveReports() { @@ -46,17 +46,8 @@ public void testHttpGetLiveReports() Mockito.when(controller.liveReports()) .thenReturn(reportMap); - MSQControllerTask task = Mockito.mock(MSQControllerTask.class); - Mockito.when(task.getDataSource()) .thenReturn("wiki"); - Mockito.when(controller.task()) .thenReturn(task); - - TaskToolbox toolbox = Mockito.mock(TaskToolbox.class); - Mockito.when(toolbox.getAuthorizerMapper()) .thenReturn(new AuthorizerMapper(null)); - - ControllerChatHandler chatHandler = new ControllerChatHandler(toolbox, controller); + final AuthorizerMapper authorizerMapper = new AuthorizerMapper(null); + ControllerChatHandler chatHandler = new ControllerChatHandler(controller, DATASOURCE, authorizerMapper); HttpServletRequest httpRequest = Mockito.mock(HttpServletRequest.class); Mockito.when(httpRequest.getAttribute(ArgumentMatchers.anyString())) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java index 158f65a05940..4ab992aec096 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java @@ -30,9 +30,6 @@ import org.apache.druid.indexer.report.SingleFileTaskReportFileWriter; import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.java.util.common.DateTimes; -import org.apache.druid.java.util.common.guava.Sequences; -import org.apache.druid.java.util.common.guava.Yielder; -import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.exec.SegmentLoadStatusFetcher; import org.apache.druid.msq.guice.MSQIndexingModule; @@ -52,10 +49,10 @@ import java.io.File; import java.util.ArrayDeque; -import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import 
java.util.List; +import java.util.UUID; public class MSQTaskReportTest { @@ -63,7 +60,7 @@ public class MSQTaskReportTest private static final String HOST = "example.com:1234"; public static final QueryDefinition QUERY_DEFINITION = QueryDefinition - .builder() + .builder(UUID.randomUUID().toString()) .add( StageDefinition .builder(0) @@ -112,13 +109,14 @@ public void testSerdeResultsReport() throws Exception ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), + ImmutableMap.of(), ImmutableMap.of() ), new CounterSnapshotsTree(), new MSQResultsReport( Collections.singletonList(new MSQResultsReport.ColumnAndType("s", ColumnType.STRING)), ImmutableList.of(SqlTypeName.VARCHAR), - Yielders.each(Sequences.simple(results)), + results, null ) ) @@ -139,13 +137,7 @@ public void testSerdeResultsReport() throws Exception Assert.assertEquals(report.getPayload().getStatus().getPendingTasks(), report2.getPayload().getStatus().getPendingTasks()); Assert.assertEquals(report.getPayload().getStages(), report2.getPayload().getStages()); - Yielder yielder = report2.getPayload().getResults().getResultYielder(); - final List results2 = new ArrayList<>(); - - while (!yielder.isDone()) { - results2.add(yielder.get()); - yielder = yielder.next(null); - } + final List results2 = report2.getPayload().getResults().getResults(); Assert.assertEquals(results.size(), results2.size()); for (int i = 0; i < results.size(); i++) { Assert.assertArrayEquals(results.get(i), results2.get(i)); @@ -177,6 +169,7 @@ public void testSerdeErrorReport() throws Exception ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), + ImmutableMap.of(), ImmutableMap.of() ), new CounterSnapshotsTree(), @@ -225,6 +218,7 @@ public void testWriteTaskReport() throws Exception ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), + ImmutableMap.of(), ImmutableMap.of() ), new CounterSnapshotsTree(), diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSliceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSliceTest.java index 3b2705c8ba6c..d550fef84c77 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSliceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSliceTest.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.guice.MSQIndexingModule; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.segment.TestHelper; @@ -37,7 +38,8 @@ public void testSerde() throws Exception final StageInputSlice slice = new StageInputSlice( 2, - ReadablePartitions.striped(2, 3, 4) + ReadablePartitions.striped(2, 3, 4), + OutputChannelMode.MEMORY ); Assert.assertEquals( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSpecSlicerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSpecSlicerTest.java index 43d89e7fc690..024ad956cf29 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSpecSlicerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSpecSlicerTest.java @@ -24,6 +24,7 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import 
it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; +import org.apache.druid.msq.exec.OutputChannelMode; import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; import org.junit.Assert; @@ -43,12 +44,21 @@ public class StageInputSpecSlicerTest .build() ); + private static final Int2ObjectMap STAGE_OUTPUT_MODE_MAP = + new Int2ObjectOpenHashMap<>( + ImmutableMap.builder() + .put(0, OutputChannelMode.LOCAL_STORAGE) + .put(1, OutputChannelMode.LOCAL_STORAGE) + .put(2, OutputChannelMode.LOCAL_STORAGE) + .build() + ); + private StageInputSpecSlicer slicer; @Before public void setUp() { - slicer = new StageInputSpecSlicer(STAGE_PARTITIONS_MAP); + slicer = new StageInputSpecSlicer(STAGE_PARTITIONS_MAP, STAGE_OUTPUT_MODE_MAP); } @Test @@ -64,7 +74,8 @@ public void test_sliceStatic_stageZeroOneSlice() Collections.singletonList( new StageInputSlice( 0, - ReadablePartitions.striped(0, 2, 2) + ReadablePartitions.striped(0, 2, 2), + OutputChannelMode.LOCAL_STORAGE ) ), slicer.sliceStatic(new StageInputSpec(0), 1) @@ -78,11 +89,13 @@ public void test_sliceStatic_stageZeroTwoSlices() ImmutableList.of( new StageInputSlice( 0, - new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0})) + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0})), + OutputChannelMode.LOCAL_STORAGE ), new StageInputSlice( 0, - new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{1})) + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{1})), + OutputChannelMode.LOCAL_STORAGE ) ), slicer.sliceStatic(new StageInputSpec(0), 2) @@ -96,11 +109,13 @@ public void test_sliceStatic_stageOneTwoSlices() ImmutableList.of( new StageInputSlice( 1, - new StripedReadablePartitions(1, 2, new IntAVLTreeSet(new int[]{0, 2})) + new StripedReadablePartitions(1, 2, new IntAVLTreeSet(new int[]{0, 2})), + OutputChannelMode.LOCAL_STORAGE ), new StageInputSlice( 1, - new StripedReadablePartitions(1, 2, new IntAVLTreeSet(new int[]{1, 3})) + new StripedReadablePartitions(1, 2, new IntAVLTreeSet(new int[]{1, 3})), + OutputChannelMode.LOCAL_STORAGE ) ), slicer.sliceStatic(new StageInputSpec(1), 2) @@ -115,6 +130,6 @@ public void test_sliceStatic_notAvailable() () -> slicer.sliceStatic(new StageInputSpec(3), 1) ); - MatcherAssert.assertThat(e.getMessage(), CoreMatchers.equalTo("Stage [3] not available")); + MatcherAssert.assertThat(e.getMessage(), CoreMatchers.equalTo("Stage[3] output partitions not available")); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java index 634427d01a9b..a27ae7d97804 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java @@ -19,15 +19,20 @@ package org.apache.druid.msq.input.table; +import com.google.common.collect.FluentIterable; import com.google.common.collect.ImmutableList; import org.apache.druid.data.input.StringTuple; +import org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction; +import org.apache.druid.indexing.common.actions.TaskAction; +import org.apache.druid.indexing.common.actions.TaskActionClient; import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.msq.exec.SegmentSource; import 
org.apache.druid.msq.input.NilInputSlice; -import org.apache.druid.msq.querykit.DataSegmentTimelineView; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.testing.InitializedNullHandlingTest; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.SegmentTimeline; +import org.apache.druid.timeline.VersionedIntervalTimeline; import org.apache.druid.timeline.partition.DimensionRangeShardSpec; import org.apache.druid.timeline.partition.TombstoneShardSpec; import org.junit.Assert; @@ -35,7 +40,6 @@ import org.junit.Test; import java.util.Collections; -import java.util.Optional; public class TableInputSpecSlicerTest extends InitializedNullHandlingTest { @@ -94,19 +98,44 @@ public class TableInputSpecSlicerTest extends InitializedNullHandlingTest ); private SegmentTimeline timeline; private TableInputSpecSlicer slicer; + private TaskActionClient taskActionClient; @Before public void setUp() { timeline = SegmentTimeline.forSegments(ImmutableList.of(SEGMENT1, SEGMENT2, SEGMENT3)); - DataSegmentTimelineView timelineView = (dataSource, intervals) -> { - if (DATASOURCE.equals(dataSource)) { - return Optional.of(timeline); - } else { - return Optional.empty(); + taskActionClient = new TaskActionClient() + { + @Override + @SuppressWarnings("unchecked") + public RetType submit(TaskAction taskAction) + { + if (taskAction instanceof RetrieveUsedSegmentsAction) { + final RetrieveUsedSegmentsAction retrieveUsedSegmentsAction = (RetrieveUsedSegmentsAction) taskAction; + final String dataSource = retrieveUsedSegmentsAction.getDataSource(); + + if (DATASOURCE.equals(dataSource)) { + return (RetType) FluentIterable + .from(retrieveUsedSegmentsAction.getIntervals()) + .transformAndConcat( + interval -> + VersionedIntervalTimeline.getAllObjects(timeline.lookup(interval)) + ) + .toList(); + } else { + return (RetType) Collections.emptyList(); + } + } + + throw new UnsupportedOperationException(); } }; - slicer = new TableInputSpecSlicer(timelineView); + + slicer = new TableInputSpecSlicer( + null /* not used for SegmentSource.NONE */, + taskActionClient, + SegmentSource.NONE + ); } @Test diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java index 8a5533d22cb0..857584127a9a 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java @@ -33,6 +33,8 @@ import org.junit.Assert; import org.junit.Test; +import java.util.UUID; + public class QueryDefinitionTest { @Test @@ -40,7 +42,7 @@ public void testSerde() throws Exception { final QueryDefinition queryDef = QueryDefinition - .builder() + .builder(UUID.randomUUID().toString()) .add( StageDefinition .builder(0) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/BaseControllerQueryKernelTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/BaseControllerQueryKernelTest.java index 6ae18dda1e1d..2365b5cf86bc 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/BaseControllerQueryKernelTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/BaseControllerQueryKernelTest.java @@ -19,6 +19,7 @@ package 
org.apache.druid.msq.kernel.controller; +import com.google.common.base.Function; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -28,6 +29,7 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; import org.apache.druid.msq.indexing.error.MSQFault; import org.apache.druid.msq.indexing.error.UnknownFault; import org.apache.druid.msq.input.InputSpecSlicerFactory; @@ -47,15 +49,28 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; -import java.util.stream.IntStream; public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest { public static final UnknownFault RETRIABLE_FAULT = UnknownFault.forMessage(""); - public ControllerQueryKernelTester testControllerQueryKernel(int numWorkers) + public ControllerQueryKernelTester testControllerQueryKernel() { - return new ControllerQueryKernelTester(numWorkers); + return testControllerQueryKernel(ControllerQueryKernelConfig.Builder::build); + } + + public ControllerQueryKernelTester testControllerQueryKernel( + final Function<ControllerQueryKernelConfig.Builder, ControllerQueryKernelConfig> configFn + ) + { + return new ControllerQueryKernelTester( + configFn.apply( + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(100_000_000) + .destination(TaskReportMSQDestination.instance()) + ) + ); } /** @@ -69,34 +84,29 @@ public static class ControllerQueryKernelTester private boolean initialized = false; private QueryDefinition queryDefinition = null; private ControllerQueryKernel controllerQueryKernel = null; - private InputSpecSlicerFactory inputSlicerFactory = - stagePartitionsMap -> + private final InputSpecSlicerFactory inputSlicerFactory = + (stagePartitionsMap, stageOutputChannelModeMap) -> new MapInputSpecSlicer( ImmutableMap.of( - StageInputSpec.class, new StageInputSpecSlicer(stagePartitionsMap), + StageInputSpec.class, new StageInputSpecSlicer(stagePartitionsMap, stageOutputChannelModeMap), ControllerTestInputSpec.class, new ControllerTestInputSpecSlicer() ) ); - private final int numWorkers; + private final ControllerQueryKernelConfig config; Set<Integer> setupStages = new HashSet<>(); - private ControllerQueryKernelTester(int numWorkers) + private ControllerQueryKernelTester(ControllerQueryKernelConfig config) { - this.numWorkers = numWorkers; + this.config = config; } public ControllerQueryKernelTester queryDefinition(QueryDefinition queryDefinition) { this.queryDefinition = Preconditions.checkNotNull(queryDefinition); - this.controllerQueryKernel = new ControllerQueryKernel( - queryDefinition, - 100_000_000, - true - ); + this.controllerQueryKernel = new ControllerQueryKernel(queryDefinition, config); return this; } - public ControllerQueryKernelTester setupStage( int stageNumber, ControllerStagePhase controllerStagePhase @@ -275,11 +285,17 @@ public void startWorkOrder(int stageNumber) { StageId stageId = new StageId(queryDefinition.getQueryId(), stageNumber); Preconditions.checkArgument(initialized); - IntStream.range(0, queryDefinition.getStageDefinition(stageId).getMaxWorkerCount()) - .forEach(n -> controllerQueryKernel.workOrdersSentForWorker(stageId, n)); - + controllerQueryKernel.getWorkerInputsForStage(stageId).workers() + .forEach(n -> controllerQueryKernel.workOrdersSentForWorker(stageId, n)); } + public void doneReadingInput(int stageNumber) + { + StageId stageId = new 
StageId(queryDefinition.getQueryId(), stageNumber); + Preconditions.checkArgument(initialized); + controllerQueryKernel.getWorkerInputsForStage(stageId).workers() + .forEach(n -> controllerQueryKernel.setDoneReadingInputForStageAndWorker(stageId, n)); + } public void finishStage(int stageNumber) { @@ -353,22 +369,27 @@ public void failStage(int stageNumber) public void assertStagePhase(int stageNumber, ControllerStagePhase expectedControllerStagePhase) { Preconditions.checkArgument(initialized); - ControllerStageTracker controllerStageTracker = Preconditions.checkNotNull( - controllerQueryKernel.getControllerStageKernel(stageNumber), + ControllerStageTracker controllerStageKernel = Preconditions.checkNotNull( + controllerQueryKernel.getControllerStageTracker(stageNumber), StringUtils.format("Stage kernel for stage number %d is not initialized yet", stageNumber) ); - if (controllerStageTracker.getPhase() != expectedControllerStagePhase) { + if (controllerStageKernel.getPhase() != expectedControllerStagePhase) { throw new ISE( StringUtils.format( "Stage kernel for stage number %d is in %s phase which is different from the expected phase %s", stageNumber, - controllerStageTracker.getPhase(), + controllerStageKernel.getPhase(), expectedControllerStagePhase ) ); } } + public ControllerQueryKernelConfig getConfig() + { + return config; + } + /** * Checks if the state of the BaseControllerQueryKernel is initialized properly. Currently, this is just stubbed to * return true irrespective of the actual state diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelTest.java index 8e47c470bf82..03f963b133bf 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelTest.java @@ -19,9 +19,13 @@ package org.apache.druid.msq.kernel.controller; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.ShuffleKind; import org.apache.druid.msq.kernel.worker.WorkerStagePhase; import org.junit.Assert; import org.junit.Test; @@ -34,7 +38,7 @@ public class ControllerQueryKernelTest extends BaseControllerQueryKernelTest @Test public void testCompleteDAGExecutionForSingleWorker() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 1 // | / | // 2 / 3 @@ -44,13 +48,13 @@ public void testCompleteDAGExecutionForSingleWorker() // 6 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(7) - .addVertex(0, 2) - .addVertex(1, 2) - .addVertex(1, 3) - .addVertex(2, 4) - .addVertex(3, 5) - .addVertex(4, 6) - .addVertex(5, 6) + .addEdge(0, 2) + .addEdge(1, 2) + .addEdge(1, 3) + .addEdge(2, 4) + .addEdge(3, 5) + .addEdge(4, 6) + .addEdge(5, 6) .getQueryDefinitionBuilder() .build() ); @@ -62,79 +66,196 @@ public void testCompleteDAGExecutionForSingleWorker() newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); 
effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(0, 1), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(0), newStageNumbers); Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + // Mark 0 as done. Next up will be 1. + transitionNewToResultsComplete(controllerQueryKernelTester, 0); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(1), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + // Mark 1 as done and fetch the new kernels. Next up will be 2. transitionNewToResultsComplete(controllerQueryKernelTester, 1); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); - Assert.assertEquals(ImmutableSet.of(0, 3), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(2), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + // Mark 2 as done and fetch the new kernels. Next up will be 3. + transitionNewToResultsComplete(controllerQueryKernelTester, 2); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(3), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0), effectivelyFinishedStageNumbers); - // Mark 3 as done and fetch the new kernels. 5 should be unblocked along with 0. + // Mark 3 as done and fetch the new kernels. Next up will be 4. transitionNewToResultsComplete(controllerQueryKernelTester, 3); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); - Assert.assertEquals(ImmutableSet.of(0, 5), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(4), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + Assert.assertEquals(ImmutableSet.of(0, 1), effectivelyFinishedStageNumbers); + // Mark 4 as done and fetch new kernels. Next up will be 5. + transitionNewToResultsComplete(controllerQueryKernelTester, 4); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(5), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0, 1, 2), effectivelyFinishedStageNumbers); + + // Mark 0, 1, 2 finished together. + effectivelyFinishedStageNumbers.forEach(controllerQueryKernelTester::finishStage); - // Mark 5 as done and fetch the new kernels. Only 0 is still unblocked, but 3 can now be cleaned + // Mark 5 as done and fetch new kernels. Next up will be 6, and 3 will be ready to finish. 
transitionNewToResultsComplete(controllerQueryKernelTester, 5); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); - Assert.assertEquals(ImmutableSet.of(0), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(6), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); Assert.assertEquals(ImmutableSet.of(3), effectivelyFinishedStageNumbers); - // Mark 0 as done and fetch the new kernels. This should unblock 2 + // Mark 6 as done. No more kernels left, but we can clean up 4, 5, 6 along with 3. + transitionNewToResultsComplete(controllerQueryKernelTester, 6); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(3, 4, 5, 6), effectivelyFinishedStageNumbers); + effectivelyFinishedStageNumbers.forEach(controllerQueryKernelTester::finishStage); + } + + @Test + public void testCompleteDAGExecutionForSingleWorkerWithPipelining() + { + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel( + configBuilder -> + configBuilder.maxConcurrentStages(2).pipeline(true).build() + ); + // 0 [HLS] 1 [HLS] + // | / | + // 2 [none] 3 [HLS] + // | | + // 4 [mix] 5 [HLS] + // \ / + // \ / + // 6 [none] + + final QueryDefinition queryDef = new MockQueryDefinitionBuilder(7) + .addEdge(0, 2) + .addEdge(1, 2) + .addEdge(1, 3) + .addEdge(2, 4) + .addEdge(3, 5) + .addEdge(4, 6) + .addEdge(5, 6) + .defineStage(0, ShuffleKind.HASH_LOCAL_SORT) + .defineStage(1, ShuffleKind.HASH_LOCAL_SORT) + .defineStage(3, ShuffleKind.HASH_LOCAL_SORT) + .defineStage(4, ShuffleKind.MIX) + .defineStage(5, ShuffleKind.HASH_LOCAL_SORT) + .getQueryDefinitionBuilder() + .build(); + + controllerQueryKernelTester.queryDefinition(queryDef); + controllerQueryKernelTester.init(); + + Assert.assertEquals( + ImmutableList.of( + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2, 4), + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 3), + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 5), + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 6) + ), + ControllerQueryKernelUtils.computeStageGroups(queryDef, controllerQueryKernelTester.getConfig()) + ); + + Set newStageNumbers; + Set effectivelyFinishedStageNumbers; + + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0, 1), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + + + transitionNewToResultsComplete(controllerQueryKernelTester, 1); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(), 
effectivelyFinishedStageNumbers); + + + // Mark 0 as done and fetch the new kernels. 2 should be unblocked along with 4. transitionNewToResultsComplete(controllerQueryKernelTester, 0); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); - Assert.assertEquals(ImmutableSet.of(2), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(2, 4), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(3), effectivelyFinishedStageNumbers); + Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + - // Mark 2 as done and fetch new kernels. This should clear up 0 and 1 alongside 3 (which is not marked as FINISHED yet) + // Mark 2 as done and fetch the new kernels. 4 is still ready, 0 can now be cleaned, and 3 can be launched transitionNewToResultsComplete(controllerQueryKernelTester, 2); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); - Assert.assertEquals(ImmutableSet.of(4), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(3, 4), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0), effectivelyFinishedStageNumbers); + + // Mark 4 as done and fetch the new kernels. 3 is still ready, and 2 becomes cleanable + transitionNewToResultsComplete(controllerQueryKernelTester, 4); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(3), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(0, 1, 3), effectivelyFinishedStageNumbers); + Assert.assertEquals(ImmutableSet.of(0, 2), effectivelyFinishedStageNumbers); - // Mark 0, 1, 3 finished together + // Mark 3 as post-reading and fetch new kernels. This makes 1 cleanable, and 5 ready to run + transitionNewToDoneReadingInput(controllerQueryKernelTester, 3); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(5), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0, 1, 2), effectivelyFinishedStageNumbers); + + // Mark 0, 1, 2 finished together effectivelyFinishedStageNumbers.forEach(controllerQueryKernelTester::finishStage); - // Mark 4 as done and fetch new kernels. This should unblock 6 and clear up 2 - transitionNewToResultsComplete(controllerQueryKernelTester, 4); + // Mark 5 as post-reading and fetch new kernels. Nothing is ready, since 6 is waiting for 5 to finish + // However, this does clear up 3 to become cleanable + transitionNewToDoneReadingInput(controllerQueryKernelTester, 5); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(3), effectivelyFinishedStageNumbers); + + // Mark 5 as done. 
This makes 6 ready to go + transitionDoneReadingInputToResultsComplete(controllerQueryKernelTester, 5); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); Assert.assertEquals(ImmutableSet.of(6), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(2), effectivelyFinishedStageNumbers); + Assert.assertEquals(ImmutableSet.of(3), effectivelyFinishedStageNumbers); - // Mark 6 as done. No more kernels left, but we can clean up 4 and 5 alongwith 2 + // Mark 6 as done. No more kernels left, but we can clean up 4, 5, and 6 along with 3 transitionNewToResultsComplete(controllerQueryKernelTester, 6); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); Assert.assertEquals(ImmutableSet.of(), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(2, 4, 5), effectivelyFinishedStageNumbers); + Assert.assertEquals(ImmutableSet.of(3, 4, 5, 6), effectivelyFinishedStageNumbers); effectivelyFinishedStageNumbers.forEach(controllerQueryKernelTester::finishStage); } @Test public void testCompleteDAGExecutionForMultipleWorkers() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(2); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 -> 1 -> 2 -> 3 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(4) - .addVertex(0, 1) - .addVertex(1, 2) - .addVertex(2, 3) - .defineStage(0, true, 1) // Ingestion only on one worker - .defineStage(1, true, 2) - .defineStage(3, true, 2) + .addEdge(0, 1) + .addEdge(1, 2) + .addEdge(2, 3) + .defineStage(0, ShuffleKind.GLOBAL_SORT, 1) // Ingestion only on one worker + .defineStage(1, ShuffleKind.GLOBAL_SORT, 2) + .defineStage(3, ShuffleKind.GLOBAL_SORT, 2) .getQueryDefinitionBuilder() .build() ); @@ -233,12 +354,12 @@ public void testCompleteDAGExecutionForMultipleWorkers() @Test public void testTransitionsInShufflingStagesAndMultipleWorkers() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(2); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // Single stage query definition controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(1) - .defineStage(0, true, 2) + .defineStage(0, ShuffleKind.GLOBAL_SORT, 2) .getQueryDefinitionBuilder() .build() ); @@ -275,12 +396,12 @@ public void testTransitionsInShufflingStagesAndMultipleWorkers() @Test public void testPrematureResultsComplete() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(2); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // Single stage query definition controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(1) - .defineStage(0, true, 2) + .defineStage(0, ShuffleKind.GLOBAL_SORT, 2) .getQueryDefinitionBuilder() .build() ); @@ -311,15 +432,18 @@ public void testPrematureResultsComplete() @Test public void testKernelFailed() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel( + configBuilder -> + configBuilder.maxConcurrentStages(2).build() + ); // 0 1 // \ / // 2 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(3) - .addVertex(0, 2) - 
.addVertex(1, 2) + .addEdge(0, 2) + .addEdge(1, 2) .getQueryDefinitionBuilder() .build() ); @@ -340,16 +464,16 @@ public void testKernelFailed() @Test(expected = IllegalStateException.class) public void testCycleInvalidQueryThrowsException() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 - 1 // \ / // 2 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(3) - .addVertex(0, 1) - .addVertex(1, 2) - .addVertex(2, 0) + .addEdge(0, 1) + .addEdge(1, 2) + .addEdge(2, 0) .getQueryDefinitionBuilder() .build() ); @@ -358,13 +482,13 @@ public void testCycleInvalidQueryThrowsException() @Test(expected = IllegalStateException.class) public void testSelfLoopInvalidQueryThrowsException() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 _ // |__| controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(1) - .addVertex(0, 0) + .addEdge(0, 0) .getQueryDefinitionBuilder() .build() ); @@ -373,15 +497,15 @@ public void testSelfLoopInvalidQueryThrowsException() @Test(expected = IllegalStateException.class) public void testLoopInvalidQueryThrowsException() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 - 1 // | | // --- controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(2) - .addVertex(0, 1) - .addVertex(1, 0) + .addEdge(0, 1) + .addEdge(1, 0) .getQueryDefinitionBuilder() .build() ); @@ -390,15 +514,15 @@ public void testLoopInvalidQueryThrowsException() @Test public void testMarkSuccessfulTerminalStagesAsFinished() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 1 // \ / // 2 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(3) - .addVertex(0, 2) - .addVertex(1, 2) + .addEdge(0, 2) + .addEdge(1, 2) .getQueryDefinitionBuilder() .build() ); @@ -409,8 +533,8 @@ public void testMarkSuccessfulTerminalStagesAsFinished() controllerQueryKernelTester.init(); - Assert.assertTrue(controllerQueryKernelTester.isDone()); - Assert.assertTrue(controllerQueryKernelTester.isSuccess()); + Assert.assertFalse(controllerQueryKernelTester.isDone()); + Assert.assertFalse(controllerQueryKernelTester.isSuccess()); controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.FINISHED); controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.RESULTS_READY); @@ -430,4 +554,18 @@ private static void transitionNewToResultsComplete(ControllerQueryKernelTester q queryKernelTester.setResultsCompleteForStageAndWorkers(stageNumber, 0); } + private static void transitionNewToDoneReadingInput(ControllerQueryKernelTester queryKernelTester, int stageNumber) + { + queryKernelTester.startStage(stageNumber); + queryKernelTester.startWorkOrder(stageNumber); + queryKernelTester.doneReadingInput(stageNumber); + } + + private static void transitionDoneReadingInputToResultsComplete( + ControllerQueryKernelTester queryKernelTester, + int stageNumber + ) + { + queryKernelTester.setResultsCompleteForStageAndWorkers(stageNumber, 0); + } } diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtilsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtilsTest.java new file mode 100644 index 000000000000..b6bb5bb3d4e0 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtilsTest.java @@ -0,0 +1,551 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.kernel.controller; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; +import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.ShuffleKind; +import org.apache.druid.msq.kernel.StageId; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; +import java.util.stream.Collectors; + +public class ControllerQueryKernelUtilsTest +{ + @Test + public void test_computeStageGroups_multiPronged() + { + final QueryDefinition queryDef = makeMultiProngedQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 4), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 5), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 6) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(1) + .pipeline(false) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_multiPronged_pipeline_twoAtOnce() + { + final QueryDefinition queryDef = makeMultiProngedQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2, 4), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3, 5), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 6) + ), + 
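Every case in this new test file feeds the same builder shape into computeStageGroups and varies only a handful of flags. As a reading aid, here is that shared shape with the effect of each flag noted (the comments are editorial, inferred from the assertions in this file):

```java
final ControllerQueryKernelConfig config =
    ControllerQueryKernelConfig
        .builder()
        .maxRetainedPartitionSketchBytes(1)
        .maxConcurrentStages(1)  // raising to 2 lets adjacent stages run, and group, together
        .pipeline(false)         // true lets grouped stages hand results off via MEMORY channels
        .faultTolerance(false)   // true forces DURABLE_STORAGE_* output modes
        .destination(TaskReportMSQDestination.instance())
        .build();
```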
ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(true) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithoutShuffle() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithoutShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(1) + .pipeline(false) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithShuffle() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(1) + .pipeline(false) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithoutShuffle_faultTolerant() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithoutShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + .faultTolerance(true) + .durableStorage(true) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithShuffle_faultTolerant() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + 
.faultTolerance(true) + .durableStorage(true) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithoutShuffle_faultTolerant_durableResults() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithoutShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + .faultTolerance(true) + .durableStorage(true) + .destination(DurableStorageMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithShuffle_faultTolerant_durableResults() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + .faultTolerance(true) + .durableStorage(true) + .destination(DurableStorageMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithoutShuffle_pipeline_twoAtOnce() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithoutShuffle(); + + Assert.assertEquals( + // Without a sort-based shuffle, we can't leapfrog, so we launch two groups broken up by LOCAL_STORAGE + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 2, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(true) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithShuffle_pipeline_twoAtOnce() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithShuffle(); + + Assert.assertEquals( + // With sort-based shuffle, we can leapfrog 4 stages, all of them being in-memory + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(true) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) 
+ .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanIn() + { + final QueryDefinition queryDef = makeFanInQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(1) + .pipeline(false) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanInWithBroadcast() + { + final QueryDefinition queryDef = makeFanInQueryDefinitionWithBroadcast(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(1) + .pipeline(false) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanIn_faultTolerant() + { + final QueryDefinition queryDef = makeFanInQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + .faultTolerance(true) + .durableStorage(true) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanIn_faultTolerant_durableResults() + { + final QueryDefinition queryDef = makeFanInQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + .faultTolerance(true) + .durableStorage(true) + .destination(DurableStorageMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanIn_pipeline_twoAtOnce() + { + final QueryDefinition queryDef = makeFanInQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + 
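Taken together, the fault-tolerant and durable-results cases pin down how an output channel mode gets chosen. A rough editorial paraphrase of the rule these assertions encode, not the actual ControllerQueryKernelUtils logic:

```java
// Hypothetical restatement; chooseMode is not a real method in this patch.
static OutputChannelMode chooseMode(
    final boolean faultTolerant,
    final boolean durableResultsDestination,
    final boolean finalStage,
    final boolean consumerRunsConcurrently
)
{
  if (finalStage && durableResultsDestination) {
    return OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS;
  } else if (faultTolerant) {
    return OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE;
  } else if (consumerRunsConcurrently) {
    return OutputChannelMode.MEMORY;  // downstream stage is live; stream to it directly
  } else {
    return OutputChannelMode.LOCAL_STORAGE;
  }
}
```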
makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 2, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(true) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanInWithBroadcast_pipeline_twoAtOnce() + { + final QueryDefinition queryDef = makeFanInQueryDefinitionWithBroadcast(); + + Assert.assertEquals( + // Output of stage 1 is broadcast, so it must run first; then stages 0 and 2 may be launched together + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(true) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + private static QueryDefinition makeLinearQueryDefinitionWithShuffle() + { + // 0 -> 1 -> 2 -> 3 + + return new MockQueryDefinitionBuilder(4) + .addEdge(0, 1) + .addEdge(1, 2) + .addEdge(2, 3) + .defineStage(0, ShuffleKind.GLOBAL_SORT) + .defineStage(1, ShuffleKind.GLOBAL_SORT) + .defineStage(2, ShuffleKind.GLOBAL_SORT) + .defineStage(3, ShuffleKind.GLOBAL_SORT) + .getQueryDefinitionBuilder() + .build(); + } + + private static QueryDefinition makeLinearQueryDefinitionWithoutShuffle() + { + // 0 -> 1 -> 2 -> 3 + + return new MockQueryDefinitionBuilder(4) + .addEdge(0, 1) + .addEdge(1, 2) + .addEdge(2, 3) + .getQueryDefinitionBuilder() + .build(); + } + + private static QueryDefinition makeFanInQueryDefinition() + { + // 0 -> 2 -> 3 + // / + // 1 + + return new MockQueryDefinitionBuilder(4) + .addEdge(0, 2) + .addEdge(1, 2) + .addEdge(2, 3) + .getQueryDefinitionBuilder() + .build(); + } + + private static QueryDefinition makeFanInQueryDefinitionWithBroadcast() + { + // 0 -> 2 -> 3 + // / < broadcast + // 1 + + return new MockQueryDefinitionBuilder(4) + .addEdge(0, 2) + .addEdge(1, 2, true) + .addEdge(2, 3) + .getQueryDefinitionBuilder() + .build(); + } + + private static QueryDefinition makeMultiProngedQueryDefinition() + { + // 0 1 + // | / | + // 2 / 3 + // | | + // 4 5 + // \ / + // 6 + + return new MockQueryDefinitionBuilder(7) + .addEdge(0, 2) + .addEdge(1, 2) + .addEdge(1, 3) + .addEdge(2, 4) + .addEdge(3, 5) + .addEdge(4, 6) + .addEdge(5, 6) + .getQueryDefinitionBuilder() + .build(); + } + + public static StageGroup makeStageGroup( + final String queryId, + final OutputChannelMode outputChannelMode, + final int... 
stageNumbers + ) + { + return new StageGroup( + Arrays.stream(stageNumbers).mapToObj(n -> new StageId(queryId, n)).collect(Collectors.toList()), + outputChannelMode + ); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java index f16e35e6e283..6ac399f70e4a 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java @@ -21,6 +21,9 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntBooleanPair; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.KeyColumn; import org.apache.druid.frame.key.KeyOrder; @@ -30,21 +33,26 @@ import org.apache.druid.msq.input.stage.StageInputSpec; import org.apache.druid.msq.kernel.FrameProcessorFactory; import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec; +import org.apache.druid.msq.kernel.HashShuffleSpec; +import org.apache.druid.msq.kernel.MixShuffleSpec; import org.apache.druid.msq.kernel.QueryDefinition; import org.apache.druid.msq.kernel.QueryDefinitionBuilder; +import org.apache.druid.msq.kernel.ShuffleKind; import org.apache.druid.msq.kernel.ShuffleSpec; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; import org.mockito.Mockito; +import javax.annotation.Nullable; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; +import java.util.UUID; public class MockQueryDefinitionBuilder { @@ -56,14 +64,14 @@ public class MockQueryDefinitionBuilder private final int numStages; // Maps a stage to all the other stages on which it has dependency, i.e. for an edge like A -> B, the adjacency list - // would have an entry like B : [ A, ... ] - private final Map> adjacencyList = new HashMap<>(); + // would have an entry like B : [ , ... 
] + private final Map> adjacencyList = new HashMap<>(); // Keeps a collection of those stages that have been already defined private final Set definedStages = new HashSet<>(); // Query definition builder corresponding to this mock builder - private final QueryDefinitionBuilder queryDefinitionBuilder = QueryDefinition.builder(); + private final QueryDefinitionBuilder queryDefinitionBuilder = QueryDefinition.builder(UUID.randomUUID().toString()); public MockQueryDefinitionBuilder(final int numStages) @@ -71,35 +79,40 @@ public MockQueryDefinitionBuilder(final int numStages) this.numStages = numStages; } - public MockQueryDefinitionBuilder addVertex(final int outEdge, final int inEdge) + public MockQueryDefinitionBuilder addEdge(final int outVertex, final int inVertex) + { + return addEdge(outVertex, inVertex, false); + } + + public MockQueryDefinitionBuilder addEdge(final int outVertex, final int inVertex, final boolean broadcast) { Preconditions.checkArgument( - outEdge < numStages, + outVertex < numStages, "vertex number can only be from 0 to one less than the total number of stages" ); Preconditions.checkArgument( - inEdge < numStages, + inVertex < numStages, "vertex number can only be from 0 to one less than the total number of stages" ); Preconditions.checkArgument( - !definedStages.contains(inEdge), - StringUtils.format("%s is already defined, cannot create more connections from it", inEdge) + !definedStages.contains(inVertex), + StringUtils.format("%s is already defined, cannot create more connections from it", inVertex) ); Preconditions.checkArgument( - !definedStages.contains(outEdge), - StringUtils.format("%s is already defined, cannot create more connections to it", outEdge) + !definedStages.contains(outVertex), + StringUtils.format("%s is already defined, cannot create more connections to it", outVertex) ); - adjacencyList.computeIfAbsent(inEdge, k -> new HashSet<>()).add(outEdge); + adjacencyList.computeIfAbsent(inVertex, k -> new HashSet<>()).add(IntBooleanPair.of(outVertex, broadcast)); return this; } public MockQueryDefinitionBuilder defineStage( int stageNumber, - boolean shuffling, + @Nullable ShuffleKind shuffleKind, int maxWorkers ) { @@ -113,27 +126,60 @@ public MockQueryDefinitionBuilder defineStage( ); definedStages.add(stageNumber); - ShuffleSpec shuffleSpec; + ShuffleSpec shuffleSpec = null; - if (shuffling) { - shuffleSpec = new GlobalSortMaxCountShuffleSpec( - new ClusterBy( - ImmutableList.of( - new KeyColumn(SHUFFLE_KEY_COLUMN, KeyOrder.ASCENDING) + if (shuffleKind != null) { + switch (shuffleKind) { + case GLOBAL_SORT: + shuffleSpec = new GlobalSortMaxCountShuffleSpec( + new ClusterBy( + ImmutableList.of( + new KeyColumn(SHUFFLE_KEY_COLUMN, KeyOrder.ASCENDING) + ), + 0 ), - 0 - ), - MAX_NUM_PARTITIONS, - false - ); - } else { - shuffleSpec = null; + MAX_NUM_PARTITIONS, + false + ); + break; + + case HASH_LOCAL_SORT: + case HASH: + shuffleSpec = new HashShuffleSpec( + new ClusterBy( + ImmutableList.of( + new KeyColumn( + SHUFFLE_KEY_COLUMN, + shuffleKind == ShuffleKind.HASH ? 
KeyOrder.NONE : KeyOrder.ASCENDING + ) + ), + 0 + ), + MAX_NUM_PARTITIONS + ); + break; + + case MIX: + shuffleSpec = MixShuffleSpec.instance(); + break; + } + + if (shuffleSpec == null || shuffleKind != shuffleSpec.kind()) { + throw new ISE("Oops, created an incorrect shuffleSpec[%s] for kind[%s]", shuffleSpec, shuffleKind); + } } - final List inputSpecs = - adjacencyList.getOrDefault(stageNumber, new HashSet<>()) - .stream() - .map(StageInputSpec::new).collect(Collectors.toList()); + final List inputSpecs = new ArrayList<>(); + final IntSet broadcastInputNumbers = new IntOpenHashSet(); + + int inputNumber = 0; + for (final IntBooleanPair pair : adjacencyList.getOrDefault(stageNumber, Collections.emptySet())) { + inputSpecs.add(new StageInputSpec(pair.leftInt())); + if (pair.rightBoolean()) { + broadcastInputNumbers.add(inputNumber); + } + inputNumber++; + } if (inputSpecs.isEmpty()) { for (int i = 0; i < maxWorkers; i++) { @@ -144,6 +190,7 @@ public MockQueryDefinitionBuilder defineStage( queryDefinitionBuilder.add( StageDefinition.builder(stageNumber) .inputs(inputSpecs) + .broadcastInputs(broadcastInputNumbers) .processorFactory(Mockito.mock(FrameProcessorFactory.class)) .shuffleSpec(shuffleSpec) .signature(RowSignature.builder().add(SHUFFLE_KEY_COLUMN, ColumnType.STRING).build()) @@ -153,14 +200,14 @@ public MockQueryDefinitionBuilder defineStage( return this; } - public MockQueryDefinitionBuilder defineStage(int stageNumber, boolean shuffling) + public MockQueryDefinitionBuilder defineStage(int stageNumber, @Nullable ShuffleKind shuffleKind) { - return defineStage(stageNumber, shuffling, 1); + return defineStage(stageNumber, shuffleKind, 1); } public MockQueryDefinitionBuilder defineStage(int stageNumber) { - return defineStage(stageNumber, false); + return defineStage(stageNumber, null); } public QueryDefinitionBuilder getQueryDefinitionBuilder() @@ -205,8 +252,8 @@ private boolean checkAcyclic(int node, Map visited) return false; } else { visited.put(node, StageState.VISITING); - for (int neighbour : adjacencyList.getOrDefault(node, Collections.emptySet())) { - if (!checkAcyclic(neighbour, visited)) { + for (IntBooleanPair neighbour : adjacencyList.getOrDefault(node, Collections.emptySet())) { + if (!checkAcyclic(neighbour.leftInt(), visited)) { return false; } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/NonShufflingWorkersWithRetryKernelTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/NonShufflingWorkersWithRetryKernelTest.java index d408b47da8e1..fb5af9e7c4f6 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/NonShufflingWorkersWithRetryKernelTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/NonShufflingWorkersWithRetryKernelTest.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.kernel.controller; +import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; import org.junit.Assert; import org.junit.Test; @@ -318,13 +319,20 @@ public void testMultipleWorkersFailedBeforeAllResultsRecieved() @Nonnull private ControllerQueryKernelTester getSimpleQueryDefinition(int numWorkers) { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(numWorkers); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel( + configBuilder -> + configBuilder + .destination(DurableStorageMSQDestination.instance()) + 
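For orientation, the reworked builder reads like this in use (assembled from the tests in this patch; stages left undefined are filled in by getQueryDefinitionBuilder with no shuffle):

```java
final QueryDefinition queryDef = new MockQueryDefinitionBuilder(3)
    .addEdge(0, 2)
    .addEdge(1, 2, true)                      // third argument marks this input as broadcast
    .defineStage(0, ShuffleKind.GLOBAL_SORT)  // GlobalSortMaxCountShuffleSpec
    .defineStage(1, ShuffleKind.HASH)         // HashShuffleSpec with KeyOrder.NONE
    .getQueryDefinitionBuilder()
    .build();
```

The retry tests that follow also now spell out their fault-tolerance configuration explicitly, which lines up with the runtime rule that fault tolerance only makes sense with durable storage enabled: a retried worker has to find its upstream inputs still available.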
.durableStorage(true) + .faultTolerance(true) + .build() + ); // 0 -> 1 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(2) - .addVertex(0, 1) - .defineStage(0, false, numWorkers) - .defineStage(1, false, numWorkers) + .addEdge(0, 1) + .defineStage(0, null, numWorkers) + .defineStage(1, null, numWorkers) .getQueryDefinitionBuilder() .build() ); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ShufflingWorkersWithRetryKernelTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ShufflingWorkersWithRetryKernelTest.java index 824c23b4fb11..81addb183ba3 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ShufflingWorkersWithRetryKernelTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ShufflingWorkersWithRetryKernelTest.java @@ -19,6 +19,8 @@ package org.apache.druid.msq.kernel.controller; +import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; +import org.apache.druid.msq.kernel.ShuffleKind; import org.junit.Assert; import org.junit.Test; @@ -1071,13 +1073,20 @@ public void testMultipleWorkersFailedBeforeAllResultsReceived() @Nonnull private ControllerQueryKernelTester getSimpleQueryDefinition(int numWorkers) { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(numWorkers); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel( + configBuilder -> + configBuilder + .destination(DurableStorageMSQDestination.instance()) + .durableStorage(true) + .faultTolerance(true) + .build() + ); // 0 -> 1 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(2) - .addVertex(0, 1) - .defineStage(0, true, numWorkers) - .defineStage(1, true, numWorkers) + .addEdge(0, 1) + .defineStage(0, ShuffleKind.GLOBAL_SORT, numWorkers) + .defineStage(1, ShuffleKind.GLOBAL_SORT, numWorkers) .getQueryDefinitionBuilder() .build() ); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java index 00ccfdee6c19..605e0bf2de74 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java @@ -29,6 +29,7 @@ import it.unimi.dsi.fastutil.longs.LongList; import nl.jqno.equalsverifier.EqualsVerifier; import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.msq.input.InputSpec; import org.apache.druid.msq.input.InputSpecSlicer; @@ -238,7 +239,8 @@ public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_fourWorkerMa stageDef, new Int2IntAVLTreeMap(ImmutableMap.of(0, 2)), new StageInputSpecSlicer( - new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, ReadablePartitions.striped(0, 2, 3))) + new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, ReadablePartitions.striped(0, 2, 3))), + new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)) ), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER @@ -251,7 +253,8 @@ public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_fourWorkerMa Collections.singletonList( new 
StageInputSlice( 0, - new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0, 2})) + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0, 2})), + OutputChannelMode.LOCAL_STORAGE ) ) ) @@ -260,7 +263,8 @@ public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_fourWorkerMa Collections.singletonList( new StageInputSlice( 0, - new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{1})) + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{1})), + OutputChannelMode.LOCAL_STORAGE ) ) ) @@ -283,7 +287,8 @@ public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_oneWorkerMax stageDef, new Int2IntAVLTreeMap(ImmutableMap.of(0, 2)), new StageInputSpecSlicer( - new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, ReadablePartitions.striped(0, 2, 3))) + new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, ReadablePartitions.striped(0, 2, 3))), + new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)) ), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER @@ -296,7 +301,8 @@ public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_oneWorkerMax Collections.singletonList( new StageInputSlice( 0, - new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0, 1, 2})) + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0, 1, 2})), + OutputChannelMode.LOCAL_STORAGE ) ) ) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java index b80e59223f78..2da5fd42caf1 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java @@ -37,8 +37,6 @@ import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; -import org.apache.druid.java.util.common.guava.Sequences; -import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.http.client.response.StringFullResponseHolder; import org.apache.druid.msq.counters.ChannelCounters; import org.apache.druid.msq.counters.CounterSnapshots; @@ -247,8 +245,8 @@ public class SqlStatementResourceTest extends MSQTestBase ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(0, 1), - ImmutableMap.of(0, 1) - + ImmutableMap.of(0, 1), + ImmutableMap.of() ), CounterSnapshotsTree.fromMap(ImmutableMap.of( 0, @@ -287,9 +285,7 @@ public class SqlStatementResourceTest extends MSQTestBase SqlTypeName.VARCHAR, SqlTypeName.VARCHAR ), - Yielders.each( - Sequences.simple( - RESULT_ROWS)), + RESULT_ROWS, null ) ) @@ -315,6 +311,7 @@ public class SqlStatementResourceTest extends MSQTestBase ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), + ImmutableMap.of(), ImmutableMap.of() ), new CounterSnapshotsTree(), diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/SendPartialKeyStatisticsInformationSerdeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/SendPartialKeyStatisticsInformationSerdeTest.java new file mode 100644 index 000000000000..254eb8a23210 --- /dev/null +++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/SendPartialKeyStatisticsInformationSerdeTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.statistics; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableSet; +import org.apache.druid.msq.guice.MSQIndexingModule; +import org.apache.druid.segment.TestHelper; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class SendPartialKeyStatisticsInformationSerdeTest +{ + private ObjectMapper objectMapper; + + @Before + public void setUp() + { + objectMapper = TestHelper.makeJsonMapper(); + objectMapper.registerModules(new MSQIndexingModule().getJacksonModules()); + objectMapper.enable(JsonParser.Feature.STRICT_DUPLICATE_DETECTION); + } + + @Test + public void testSerde() throws JsonProcessingException + { + PartialKeyStatisticsInformation partialInformation = new PartialKeyStatisticsInformation( + ImmutableSet.of(2L, 3L), + false, + 0.0 + ); + + final String json = objectMapper.writeValueAsString(partialInformation); + final PartialKeyStatisticsInformation deserializedKeyStatistics = objectMapper.readValue( + json, + PartialKeyStatisticsInformation.class + ); + Assert.assertEquals(json, partialInformation.getTimeSegments(), deserializedKeyStatistics.getTimeSegments()); + Assert.assertEquals(json, partialInformation.hasMultipleValues(), deserializedKeyStatistics.hasMultipleValues()); + Assert.assertEquals(json, partialInformation.getBytesRetained(), deserializedKeyStatistics.getBytesRetained(), 0); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java index fc7cfe5d9bea..2ebe975c39d9 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java @@ -32,6 +32,7 @@ import org.apache.druid.data.input.impl.LongDimensionSchema; import org.apache.druid.data.input.impl.StringDimensionSchema; import org.apache.druid.discovery.NodeRole; +import org.apache.druid.frame.processor.Bouncer; import org.apache.druid.guice.GuiceInjectors; import org.apache.druid.guice.IndexingServiceTuningConfigModule; import org.apache.druid.guice.JoinableFactoryModule; @@ -175,6 +176,7 @@ public String getFormatString() groupByBuffers ).getGroupingEngine(); binder.bind(GroupingEngine.class).toInstance(groupingEngine); + 
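The serde test above is a standard Jackson round-trip check; enabling STRICT_DUPLICATE_DETECTION additionally catches fields that get serialized twice. The same pattern generalizes to a small helper, sketched here as an editorial aside (SerdeTestUtil and roundTrip are hypothetical names, not part of this patch):

```java
import com.fasterxml.jackson.databind.ObjectMapper;

final class SerdeTestUtil
{
  /** Serialize then deserialize, for equality-based round-trip assertions. */
  static <T> T roundTrip(final ObjectMapper mapper, final T value, final Class<T> clazz)
      throws Exception
  {
    return mapper.readValue(mapper.writeValueAsString(value), clazz);
  }
}
```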
binder.bind(Bouncer.class).toInstance(new Bouncer(1)); }; return ImmutableList.of( customBindings, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java index d59bf6f027be..c249df61ebab 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java @@ -137,7 +137,14 @@ public SqlEngine createEngine( ) { final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance(WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, 2, 10, 2, 0, 0); + WorkerMemoryParameters.createInstance( + WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, + 2, + 10, + 2, + 0, + 0 + ); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 3b5e14cb2f55..fe78b481bee4 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -46,6 +46,7 @@ import org.apache.druid.discovery.BrokerClient; import org.apache.druid.discovery.NodeRole; import org.apache.druid.frame.channel.FrameChannelSequence; +import org.apache.druid.frame.processor.Bouncer; import org.apache.druid.frame.testutil.FrameTestUtil; import org.apache.druid.guice.DruidInjectorBuilder; import org.apache.druid.guice.DruidSecondaryModule; @@ -58,7 +59,6 @@ import org.apache.druid.guice.SegmentWranglerModule; import org.apache.druid.guice.StartupInjectorBuilder; import org.apache.druid.guice.annotations.EscalatedGlobal; -import org.apache.druid.guice.annotations.MSQ; import org.apache.druid.guice.annotations.Self; import org.apache.druid.hll.HyperLogLogCollector; import org.apache.druid.indexing.common.SegmentCacheManagerFactory; @@ -73,7 +73,6 @@ import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; -import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.http.client.Request; @@ -86,6 +85,7 @@ import org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.DataServerQueryHandler; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.ResultsContext; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.guice.MSQDurableStorageModule; import org.apache.druid.msq.guice.MSQExternalDataSourceModule; @@ -504,7 +504,9 @@ public String getFormatString() // following bindings are overriding other bindings that end up needing a lot more dependencies. // We replace the bindings with something that returns null to make things more brittle in case they // actually are used somewhere in the test. 
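A note on the Bouncer bindings introduced in CalciteMSQTestsHelper above and in MSQTestBase below: Bouncer is the frame-processing concurrency limiter, and binding it to new Bouncer(1) caps the test harness at one concurrently executing frame processor, which keeps processor execution effectively serial and deterministic in tests.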
- binder.bind(SqlStatementFactory.class).annotatedWith(MSQ.class).toProvider(Providers.of(null)); + binder.bind(SqlStatementFactory.class) + .annotatedWith(MultiStageQuery.class) + .toProvider(Providers.of(null)); binder.bind(SqlToolbox.class).toProvider(Providers.of(null)); binder.bind(MSQTaskSqlEngine.class).toProvider(Providers.of(null)); } @@ -514,7 +516,8 @@ public String getFormatString() new LookylooModule(), new SegmentWranglerModule(), new HllSketchModule(), - binder -> binder.bind(BrokerClient.class).toInstance(brokerClient) + binder -> binder.bind(BrokerClient.class).toInstance(brokerClient), + binder -> binder.bind(Bouncer.class).toInstance(new Bouncer(1)) ); // adding node role injection to the modules, since CliPeon would also do that through run method Injector injector = new CoreInjectorBuilder(new StartupInjectorBuilder().build(), ImmutableSet.of(NodeRole.PEON)) @@ -835,20 +838,7 @@ public static List getRows(@Nullable MSQResultsReport resultsReport) if (resultsReport == null) { return null; } else { - Yielder yielder = resultsReport.getResultYielder(); - List rows = new ArrayList<>(); - while (!yielder.isDone()) { - rows.add(yielder.get()); - yielder = yielder.next(null); - } - try { - yielder.close(); - } - catch (IOException e) { - throw new ISE("Unable to get results from the report"); - } - - return rows; + return resultsReport.getResults(); } } @@ -1436,9 +1426,10 @@ public Pair, List>> pageInformation.getWorker() == null ? 0 : pageInformation.getWorker(), pageInformation.getPartition() == null ? 0 : pageInformation.getPartition() )).flatMap(frame -> SqlStatementResourceHelper.getResultSequence( - msqControllerTask, - finalStage, frame, + finalStage.getFrameReader(), + msqControllerTask.getQuerySpec().getColumnMappings(), + new ResultsContext(msqControllerTask.getSqlTypeNames(), msqControllerTask.getSqlResultsContext()), objectMapper )).withBaggage(closer).toList()); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java index 3e78e477bda9..96e26cba77e1 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java @@ -54,6 +54,12 @@ public void postPartialKeyStatistics( } } + @Override + public void postDoneReadingInput(StageId stageId, int workerNumber) + { + controller.doneReadingInput(stageId.getStageNumber(), workerNumber); + } + @Override public void postCounters(String workerId, CounterSnapshotsTree snapshotsTree) { diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java index 45de3d7c4f50..20d31fbd4cfe 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.test; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -28,33 +29,47 @@ import 
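The getRows change above drops the last yielder-draining loop from this class: MSQResultsReport now carries materialized rows, so resultsReport.getResults() replaces the whole idiom. For reference, the deleted pattern (reconstructed from the removed lines, with generic types restored) was:

```java
List<Object[]> rows = new ArrayList<>();
Yielder<Object[]> yielder = resultsReport.getResultYielder();
while (!yielder.isDone()) {
  rows.add(yielder.get());
  yielder = yielder.next(null);
}
yielder.close();  // throws IOException; the removed code caught it and rethrew as ISE
return rows;
```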
com.google.inject.Injector; import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.client.coordinator.CoordinatorClient; +import org.apache.druid.client.indexing.NoopOverlordClient; +import org.apache.druid.client.indexing.TaskStatusResponse; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.indexer.RunnerTaskState; import org.apache.druid.indexer.TaskLocation; import org.apache.druid.indexer.TaskState; import org.apache.druid.indexer.TaskStatus; -import org.apache.druid.indexer.report.TaskReport; +import org.apache.druid.indexer.TaskStatusPlus; import org.apache.druid.indexing.common.actions.TaskActionClient; +import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.ControllerMemoryParameters; import org.apache.druid.msq.exec.Worker; import org.apache.druid.msq.exec.WorkerClient; +import org.apache.druid.msq.exec.WorkerFailureListener; import org.apache.druid.msq.exec.WorkerImpl; -import org.apache.druid.msq.exec.WorkerManagerClient; +import org.apache.druid.msq.exec.WorkerManager; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.exec.WorkerStorageParameters; +import org.apache.druid.msq.indexing.IndexerControllerContext; +import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.MSQWorkerTask; +import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; +import org.apache.druid.msq.input.InputSpecSlicer; +import org.apache.druid.msq.input.table.TableInputSpecSlicer; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.query.QueryContext; +import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.server.DruidNode; -import org.apache.druid.server.metrics.NoopServiceEmitter; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; import javax.annotation.Nullable; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -72,7 +87,8 @@ public class MSQTestControllerContext implements ControllerContext private final ConcurrentMap statusMap = new ConcurrentHashMap<>(); private final ListeningExecutorService executor = MoreExecutors.listeningDecorator(Execs.multiThreaded( NUM_WORKERS, - "MultiStageQuery-test-controller-client")); + "MultiStageQuery-test-controller-client" + )); private final CoordinatorClient coordinatorClient; private final DruidNode node = new DruidNode( "controller", @@ -85,18 +101,18 @@ public class MSQTestControllerContext implements ControllerContext ); private final Injector injector; private final ObjectMapper mapper; - private final ServiceEmitter emitter = new NoopServiceEmitter(); private Controller controller; - private TaskReport.ReportMap report = null; private final WorkerMemoryParameters workerMemoryParameters; + private final QueryContext queryContext; public MSQTestControllerContext( ObjectMapper mapper, Injector injector, TaskActionClient taskActionClient, WorkerMemoryParameters workerMemoryParameters, - List 
loadedSegments + List loadedSegments, + QueryContext queryContext ) { this.mapper = mapper; @@ -105,8 +121,8 @@ public MSQTestControllerContext( coordinatorClient = Mockito.mock(CoordinatorClient.class); Mockito.when(coordinatorClient.fetchServerViewSegments( - ArgumentMatchers.anyString(), - ArgumentMatchers.any() + ArgumentMatchers.anyString(), + ArgumentMatchers.any() ) ).thenAnswer(invocation -> loadedSegments.stream() .filter(immutableSegmentLoadInfo -> @@ -116,13 +132,15 @@ public MSQTestControllerContext( .collect(Collectors.toList()) ); this.workerMemoryParameters = workerMemoryParameters; + this.queryContext = queryContext; } - WorkerManagerClient workerManagerClient = new WorkerManagerClient() + OverlordClient overlordClient = new NoopOverlordClient() { @Override - public String run(String taskId, MSQWorkerTask task) + public ListenableFuture runTask(String taskId, Object taskObject) { + final MSQWorkerTask task = (MSQWorkerTask) taskObject; if (controller == null) { throw new ISE("Controller needs to be set using the register method"); } @@ -137,13 +155,26 @@ public String run(String taskId, MSQWorkerTask task) Worker worker = new WorkerImpl( task, - new MSQTestWorkerContext(inMemoryWorkers, controller, mapper, injector, workerMemoryParameters), + new MSQTestWorkerContext( + inMemoryWorkers, + controller, + mapper, + injector, + workerMemoryParameters + ), workerStorageParameters ); inMemoryWorkers.put(task.getId(), worker); statusMap.put(task.getId(), TaskStatus.running(task.getId())); - ListenableFuture future = executor.submit(worker::run); + ListenableFuture future = executor.submit(() -> { + try { + return worker.run(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + }); Futures.addCallback(future, new FutureCallback() { @@ -161,11 +192,11 @@ public void onFailure(Throwable t) } }, MoreExecutors.directExecutor()); - return task.getId(); + return Futures.immediateFuture(null); } @Override - public Map statuses(Set taskIds) + public ListenableFuture> taskStatuses(Set taskIds) { Map result = new HashMap<>(); for (String taskId : taskIds) { @@ -188,40 +219,63 @@ public Map statuses(Set taskIds) } } } - return result; + return Futures.immediateFuture(result); } @Override - public TaskLocation location(String workerId) + public ListenableFuture taskStatus(String taskId) { - final TaskStatus status = statusMap.get(workerId); - if (status != null && status.getStatusCode().equals(TaskState.RUNNING) && inMemoryWorkers.containsKey(workerId)) { - return TaskLocation.create("host-" + workerId, 1, -1); + final Map taskStatusMap = + FutureUtils.getUnchecked(taskStatuses(Collections.singleton(taskId)), true); + + final TaskStatus taskStatus = taskStatusMap.get(taskId); + if (taskStatus == null) { + return Futures.immediateFuture(new TaskStatusResponse(taskId, null)); } else { - return TaskLocation.unknown(); + return Futures.immediateFuture( + new TaskStatusResponse( + taskId, + new TaskStatusPlus( + taskStatus.getId(), + null, + null, + DateTimes.utc(0), + DateTimes.utc(0), + taskStatus.getStatusCode(), + taskStatus.getStatusCode(), + taskStatus.getStatusCode().isRunnable() ? RunnerTaskState.RUNNING : RunnerTaskState.NONE, + null, + taskStatus.getStatusCode().isRunnable() + ? 
TaskLocation.create("host-" + taskId, 1, -1) + : TaskLocation.unknown(), + null, + taskStatus.getErrorMsg() + ) + ) + ); } } @Override - public void cancel(String workerId) + public ListenableFuture cancelTask(String workerId) { final Worker worker = inMemoryWorkers.remove(workerId); if (worker != null) { worker.stopGracefully(); } - } - - @Override - public void close() - { - //do nothing + return Futures.immediateFuture(null); } }; @Override - public ServiceEmitter emitter() + public ControllerQueryKernelConfig queryKernelConfig(MSQSpec querySpec, QueryDefinition queryDef) + { + return IndexerControllerContext.makeQueryKernelConfig(querySpec, new ControllerMemoryParameters(100_000_000)); + } + + @Override + public void emitMetric(String metric, Number value) { - return emitter; } @Override @@ -243,21 +297,37 @@ public DruidNode selfNode() } @Override - public CoordinatorClient coordinatorClient() + public TaskActionClient taskActionClient() { - return coordinatorClient; + return taskActionClient; } @Override - public TaskActionClient taskActionClient() + public InputSpecSlicer newTableInputSpecSlicer() { - return taskActionClient; + return new TableInputSpecSlicer( + coordinatorClient, + taskActionClient, + MultiStageQueryContext.getSegmentSources(queryContext) + ); } @Override - public WorkerManagerClient workerManager() + public WorkerManager newWorkerManager( + String queryId, + MSQSpec querySpec, + ControllerQueryKernelConfig queryKernelConfig, + WorkerFailureListener workerFailureListener + ) { - return workerManagerClient; + return new MSQWorkerTaskLauncher( + controller.queryId(), + "test-datasource", + overlordClient, + workerFailureListener, + IndexerControllerContext.makeTaskContext(querySpec, queryKernelConfig, ImmutableMap.of()), + 0 + ); } @Override @@ -267,21 +337,8 @@ public void registerController(Controller controller, Closer closer) } @Override - public WorkerClient taskClientFor(Controller controller) + public WorkerClient newWorkerClient() { return new MSQTestWorkerClient(inMemoryWorkers); } - - @Override - public void writeReports(String controllerTaskId, TaskReport.ReportMap taskReport) - { - if (controller != null && controller.id().equals(controllerTaskId)) { - report = taskReport; - } - } - - public TaskReport.ReportMap getAllReports() - { - return report; - } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java index 4a5ac7e84e64..a565283154fd 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java @@ -20,10 +20,13 @@ package org.apache.druid.msq.test; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.google.inject.Injector; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.client.indexing.NoopOverlordClient; import org.apache.druid.client.indexing.TaskPayloadResponse; @@ -36,11 +39,19 @@ import org.apache.druid.java.util.common.ISE; import 
org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.ControllerImpl; +import org.apache.druid.msq.exec.QueryListener; +import org.apache.druid.msq.exec.ResultsContext; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.indexing.MSQControllerTask; +import org.apache.druid.msq.indexing.destination.MSQDestination; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.indexing.report.MSQStatusReport; +import org.apache.druid.msq.indexing.report.MSQTaskReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; import org.joda.time.DateTime; import javax.annotation.Nullable; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -80,33 +91,53 @@ public MSQTestOverlordServiceClient( @Override public ListenableFuture runTask(String taskId, Object taskObject) { + TestQueryListener queryListener = null; ControllerImpl controller = null; - MSQTestControllerContext msqTestControllerContext = null; + MSQTestControllerContext msqTestControllerContext; try { + MSQControllerTask cTask = objectMapper.convertValue(taskObject, MSQControllerTask.class); + msqTestControllerContext = new MSQTestControllerContext( objectMapper, injector, taskActionClient, workerMemoryParameters, - loadedSegmentMetadata + loadedSegmentMetadata, + cTask.getQuerySpec().getQuery().context() ); - MSQControllerTask cTask = objectMapper.convertValue(taskObject, MSQControllerTask.class); inMemoryControllerTask.put(cTask.getId(), cTask); - controller = new ControllerImpl(cTask, msqTestControllerContext); + controller = new ControllerImpl( + cTask.getId(), + cTask.getQuerySpec(), + new ResultsContext(cTask.getSqlTypeNames(), cTask.getSqlResultsContext()), + msqTestControllerContext + ); + + inMemoryControllers.put(controller.queryId(), controller); - inMemoryControllers.put(controller.id(), controller); + queryListener = + new TestQueryListener( + cTask.getId(), + cTask.getQuerySpec().getDestination() + ); - inMemoryTaskStatus.put(taskId, controller.run()); + try { + controller.run(queryListener); + inMemoryTaskStatus.put(taskId, queryListener.getStatusReport().toTaskStatus(cTask.getId())); + } + catch (Exception e) { + inMemoryTaskStatus.put(taskId, TaskStatus.failure(cTask.getId(), e.toString())); + } return Futures.immediateFuture(null); } catch (Exception e) { throw new ISE(e, "Unable to run"); } finally { - if (controller != null && msqTestControllerContext != null) { - reports.put(controller.id(), msqTestControllerContext.getAllReports()); + if (controller != null && queryListener != null) { + reports.put(controller.queryId(), queryListener.getReportMap()); } } } @@ -114,7 +145,7 @@ public ListenableFuture runTask(String taskId, Object taskObject) @Override public ListenableFuture cancelTask(String taskId) { - inMemoryControllers.get(taskId).stopGracefully(); + inMemoryControllers.get(taskId).stop(); return Futures.immediateFuture(null); } @@ -166,4 +197,96 @@ MSQControllerTask getMSQControllerTask(String id) { return inMemoryControllerTask.get(id); } + + /** + * Listener that captures a report and makes it available through {@link #getReportMap()}. 
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java
index ae892c34500a..72cb246a43e1 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java
@@ -54,24 +54,22 @@ public ListenableFuture<Void> postWorkOrder(String workerTaskId, WorkOrder workOrder)
 @Override
 public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
     String workerTaskId,
-    String queryId,
-    int stageNumber
+    StageId stageId
 )
 {
-  StageId stageId = new StageId(queryId, stageNumber);
   return Futures.immediateFuture(inMemoryWorkers.get(workerTaskId).fetchStatisticsSnapshot(stageId));
 }

 @Override
 public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
     String workerTaskId,
-    String queryId,
-    int stageNumber,
+    StageId stageId,
     long timeChunk
 )
 {
-  StageId stageId = new StageId(queryId, stageNumber);
-  return Futures.immediateFuture(inMemoryWorkers.get(workerTaskId).fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk));
+  return Futures.immediateFuture(
+      inMemoryWorkers.get(workerTaskId).fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk)
+  );
 }

 @Override
@@ -123,20 +121,19 @@ public ListenableFuture<Boolean> fetchChannelData(
     final ReadableByteChunksFrameChannel channel
 )
 {
-  try (InputStream inputStream = inMemoryWorkers.get(workerTaskId).readChannel(
-      stageId.getQueryId(),
-      stageId.getStageNumber(),
-      partitionNumber,
-      offset
-  )) {
+  try (InputStream inputStream =
+           inMemoryWorkers.get(workerTaskId)
+                          .readChannel(stageId.getQueryId(), stageId.getStageNumber(), partitionNumber, offset)) {
     byte[] buffer = new byte[8 * 1024];
+    boolean didRead = false;
     int bytesRead;
     while ((bytesRead = inputStream.read(buffer)) != -1) {
       channel.addChunk(Arrays.copyOf(buffer, bytesRead));
+      didRead = true;
     }
     inputStream.close();
-    return Futures.immediateFuture(true);
+    return Futures.immediateFuture(!didRead);
   }
   catch (Exception e) {
     throw new ISE(e, "Error reading frame file channel");
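The fetchChannelData rewrite above reads the channel in fixed 8 KiB chunks and signals completion by returning !didRead: the fetch that reads zero bytes is treated as the terminal one. A runnable sketch of that loop in isolation:

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.function.Consumer;

final class ChunkCopier
{
  static boolean copyChunks(InputStream in, Consumer<byte[]> chunkConsumer) throws IOException
  {
    final byte[] buffer = new byte[8 * 1024];
    boolean didRead = false;
    int bytesRead;
    while ((bytesRead = in.read(buffer)) != -1) {
      chunkConsumer.accept(Arrays.copyOf(buffer, bytesRead));
      didRead = true;
    }
    return !didRead; // true = empty read, i.e. no further fetches are needed
  }

  public static void main(String[] args) throws IOException
  {
    InputStream in = new ByteArrayInputStream(new byte[]{1, 2, 3});
    boolean done = copyChunks(in, chunk -> System.out.println(chunk.length + " bytes"));
    System.out.println("done=" + done); // false: data was read on this fetch
  }
}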
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java
index d2283a94be04..ad05c20b5829 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java
@@ -22,9 +22,9 @@
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.inject.Injector;
 import org.apache.druid.frame.processor.Bouncer;
-import org.apache.druid.indexer.report.TaskReport;
 import org.apache.druid.indexer.report.TaskReportFileWriter;
 import org.apache.druid.indexing.common.TaskToolbox;
+import org.apache.druid.indexing.common.task.NoopTestTaskReportFileWriter;
 import org.apache.druid.java.util.common.FileUtils;
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.msq.exec.Controller;
@@ -123,20 +123,7 @@ public FrameContext frameContext(QueryDefinition queryDef, int stageNumber)
     OffHeapMemorySegmentWriteOutMediumFactory.instance(),
     true
 );
-  final TaskReportFileWriter reportFileWriter = new TaskReportFileWriter()
-  {
-    @Override
-    public void write(String taskId, TaskReport.ReportMap reports)
-    {
-
-    }
-
-    @Override
-    public void setObjectMapper(ObjectMapper objectMapper)
-    {
-
-    }
-  };
+  final TaskReportFileWriter reportFileWriter = new NoopTestTaskReportFileWriter();

   return new IndexerFrameContext(
       new IndexerWorkerContext(
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/NoopQueryListener.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/NoopQueryListener.java
new file mode 100644
index 000000000000..fe38819a4519
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/NoopQueryListener.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.test;
+
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.druid.msq.exec.QueryListener;
+import org.apache.druid.msq.indexing.report.MSQResultsReport;
+import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
+
+import javax.annotation.Nullable;
+import java.util.List;
+
+public class NoopQueryListener implements QueryListener
+{
+  @Override
+  public boolean readResults()
+  {
+    return false;
+  }
+
+  @Override
+  public void onResultsStart(List<MSQResultsReport.ColumnAndType> signature, @Nullable List<SqlTypeName> sqlTypeNames)
+  {
+    // Do nothing.
+  }
+
+  @Override
+  public boolean onResultRow(Object[] row)
+  {
+    return true;
+  }
+
+  @Override
+  public void onResultsComplete()
+  {
+    // Do nothing.
+  }
+
+  @Override
+  public void onQueryComplete(MSQTaskReportPayload report)
+  {
+    // Do nothing.
+  }
+}
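NoopQueryListener, together with TestQueryListener above, implies a delivery contract for producers: skip row delivery entirely when readResults() is false, stop when onResultRow() returns false, and call onResultsComplete() only if every row was accepted. A sketch of a driver honoring that contract, with a hypothetical minimal interface rather than the real QueryListener:

import java.util.List;

final class ListenerDriver
{
  interface Listener
  {
    boolean readResults();

    boolean onResultRow(Object[] row);

    void onResultsComplete();
  }

  static void deliver(List<Object[]> rows, Listener listener)
  {
    if (!listener.readResults()) {
      return; // destination does not want rows in the report
    }
    for (Object[] row : rows) {
      if (!listener.onResultRow(row)) {
        return; // listener is full; its results will be marked truncated
      }
    }
    listener.onResultsComplete();
  }
}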
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java
index 3c14f4f1cd9f..1966d1e5b10a 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java
@@ -27,6 +27,7 @@
 import org.apache.druid.msq.counters.ChannelCounters;
 import org.apache.druid.msq.counters.CounterSnapshots;
 import org.apache.druid.msq.counters.CounterSnapshotsTree;
+import org.apache.druid.msq.exec.OutputChannelMode;
 import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination;
 import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination;
 import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
@@ -77,7 +78,8 @@ public void testDistinctPartitionsOnEachWorker()
         ImmutableMap.of(),
         ImmutableMap.of(),
         ImmutableMap.of(0, 3),
-        ImmutableMap.of(0, 15)
+        ImmutableMap.of(0, 15),
+        ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)
     ), counterSnapshots, null);
     Optional<List<PageInformation>> pages = SqlStatementResourceHelper.populatePageList(
@@ -117,7 +119,8 @@ public void testOnePartitionOnEachWorker()
         ImmutableMap.of(),
         ImmutableMap.of(),
         ImmutableMap.of(0, 4),
-        ImmutableMap.of(0, 4)
+        ImmutableMap.of(0, 4),
+        ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)
     ), counterSnapshots, null);
     Optional<List<PageInformation>> pages = SqlStatementResourceHelper.populatePageList(
@@ -158,7 +161,8 @@ public void testCommonPartitionsOnEachWorker()
         ImmutableMap.of(),
         ImmutableMap.of(),
         ImmutableMap.of(0, 4),
-        ImmutableMap.of(0, 21)
+        ImmutableMap.of(0, 21),
+        ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)
     ), counterSnapshots, null);
     Optional<List<PageInformation>> pages =
@@ -197,7 +201,8 @@ public void testNullChannelCounters()
         ImmutableMap.of(),
         ImmutableMap.of(),
         ImmutableMap.of(0, 4),
-        ImmutableMap.of(0, 21)
+        ImmutableMap.of(0, 21),
+        ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)
     ), counterSnapshots, null);
     Optional<List<PageInformation>> pages = SqlStatementResourceHelper.populatePageList(
@@ -237,7 +242,8 @@ public void testConsecutivePartitionsOnEachWorker()
         ImmutableMap.of(),
         ImmutableMap.of(),
         ImmutableMap.of(0, 4),
-        ImmutableMap.of(0, 13)
+        ImmutableMap.of(0, 13),
+        ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)
     ), counterSnapshots, null);
     Optional<List<PageInformation>> pages = SqlStatementResourceHelper.populatePageList(
@@ -278,7 +284,8 @@ public void testEmptyCountersForDurableStorageDestination()
         ImmutableMap.of(),
         ImmutableMap.of(),
         ImmutableMap.of(0, 1),
-        ImmutableMap.of(0, 1)
+        ImmutableMap.of(0, 1),
+        ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)
     ),
     counterSnapshots,
     null
@@ -315,7 +322,8 @@ public void testEmptyCountersForTaskReportDestination()
         ImmutableMap.of(),
         ImmutableMap.of(),
         ImmutableMap.of(0, 1),
-        ImmutableMap.of(0, 1)
+        ImmutableMap.of(0, 1),
+        ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)
     ),
     counterSnapshots,
     null
@@ -354,7 +362,8 @@ public void testEmptyCountersForDataSourceDestination()
         ImmutableMap.of(),
         ImmutableMap.of(),
         ImmutableMap.of(0, 1),
-        ImmutableMap.of(0, 1)
+        ImmutableMap.of(0, 1),
+        ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)
     ),
     counterSnapshots,
     null
diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriter.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriter.java
index 28cc1ae2af5e..865a8593a7d3 100644
--- a/indexing-service/src/main/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriter.java
+++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriter.java
@@ -24,10 +24,13 @@
 import org.apache.druid.indexer.report.TaskReport;
 import org.apache.druid.indexer.report.TaskReportFileWriter;
 import org.apache.druid.java.util.common.FileUtils;
+import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.logger.Logger;

 import java.io.File;
-import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.file.Files;
 import java.util.HashMap;
 import java.util.Map;
@@ -41,26 +44,29 @@ public class MultipleFileTaskReportFileWriter implements TaskReportFileWriter

 @Override
 public void write(String taskId, TaskReport.ReportMap reports)
+{
+  try (final OutputStream outputStream = openReportOutputStream(taskId)) {
+    SingleFileTaskReportFileWriter.writeReportToStream(objectMapper, outputStream, reports);
+  }
+  catch (Exception e) {
+    log.error(e, "Encountered exception in write().");
+  }
+}
+
+@Override
+public OutputStream openReportOutputStream(String taskId) throws IOException
 {
   final File reportsFile = taskReportFiles.get(taskId);
   if (reportsFile == null) {
-    log.error("Could not find report file for task[%s]", taskId);
-    return;
+    throw new ISE("Could not find report file for task[%s]", taskId);
   }

-  try {
-    final File reportsFileParent = reportsFile.getParentFile();
-    if (reportsFileParent != null) {
-      FileUtils.mkdirp(reportsFileParent);
-    }
-
-    try (final FileOutputStream outputStream = new FileOutputStream(reportsFile)) {
-      SingleFileTaskReportFileWriter.writeReportToStream(objectMapper, outputStream, reports);
-    }
-  }
-  catch (Exception e) {
-    log.error(e, "Encountered exception in write().");
+  final File reportsFileParent = reportsFile.getParentFile();
+  if (reportsFileParent != null) {
+    FileUtils.mkdirp(reportsFileParent);
   }
+
+  return Files.newOutputStream(reportsFile.toPath());
 }

 @Override
diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriterTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriterTest.java
new file mode 100644
index 000000000000..2e51973ec241
--- /dev/null
+++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriterTest.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.indexing.common;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.druid.indexer.report.IngestionStatsAndErrorsTaskReport;
+import org.apache.druid.indexer.report.TaskReport;
+import org.apache.druid.segment.TestHelper;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.Map;
+
+public class MultipleFileTaskReportFileWriterTest
+{
+  private static final String TASK_ID = "mytask";
+
+  @Rule
+  public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+  @Test
+  public void testReport() throws IOException
+  {
+    final ObjectMapper mapper = TestHelper.makeJsonMapper();
+    final File file = tempFolder.newFile();
+    final MultipleFileTaskReportFileWriter writer = new MultipleFileTaskReportFileWriter();
+    writer.setObjectMapper(mapper);
+    writer.add(TASK_ID, file);
+
+    final TaskReport.ReportMap reportsMap = TaskReport.buildTaskReports(
+        new IngestionStatsAndErrorsTaskReport(TASK_ID, null)
+    );
+
+    writer.write(TASK_ID, reportsMap);
+
+    Assert.assertEquals(
+        reportsMap,
+        mapper.readValue(Files.readAllBytes(file.toPath()), new TypeReference<Map<String, TaskReport>>() {})
+    );
+  }
+}
diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/SingleFileTaskReportFileWriterTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/SingleFileTaskReportFileWriterTest.java
new file mode 100644
index 000000000000..1381a7483cb2
--- /dev/null
+++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/SingleFileTaskReportFileWriterTest.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.indexing.common;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.druid.indexer.report.IngestionStatsAndErrorsTaskReport;
+import org.apache.druid.indexer.report.SingleFileTaskReportFileWriter;
+import org.apache.druid.indexer.report.TaskReport;
+import org.apache.druid.segment.TestHelper;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.Map;
+
+public class SingleFileTaskReportFileWriterTest
+{
+  private static final String TASK_ID = "mytask";
+
+  @Rule
+  public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+  @Test
+  public void testReport() throws IOException
+  {
+    final ObjectMapper mapper = TestHelper.makeJsonMapper();
+    final File file = tempFolder.newFile();
+    final SingleFileTaskReportFileWriter writer = new SingleFileTaskReportFileWriter(file);
+    writer.setObjectMapper(mapper);
+    final TaskReport.ReportMap reportsMap = TaskReport.buildTaskReports(
+        new IngestionStatsAndErrorsTaskReport(TASK_ID, null)
+    );
+    writer.write(TASK_ID, reportsMap);
+    Assert.assertEquals(
+        reportsMap,
+        mapper.readValue(Files.readAllBytes(file.toPath()), new TypeReference<Map<String, TaskReport>>() {})
+    );
+  }
+}
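Both new tests follow the same write-then-deserialize round trip. A generic version of the pattern with plain Jackson, where the file name and payload are illustrative rather than taken from the tests:

import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.Map;

final class RoundTrip
{
  public static void main(String[] args) throws IOException
  {
    final ObjectMapper mapper = new ObjectMapper();
    final Path file = Files.createTempFile("report", ".json");

    final Map<String, String> report = Collections.singletonMap("taskId", "mytask");
    mapper.writeValue(file.toFile(), report);

    final Map<?, ?> readBack = mapper.readValue(file.toFile(), Map.class);
    if (!report.equals(readBack)) {
      throw new AssertionError("round trip mismatch");
    }
  }
}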
diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/NoopTestTaskReportFileWriter.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/NoopTestTaskReportFileWriter.java
index 7e7860e9d8e2..ad175faeb49b 100644
--- a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/NoopTestTaskReportFileWriter.java
+++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/NoopTestTaskReportFileWriter.java
@@ -23,6 +23,9 @@
 import org.apache.druid.indexer.report.TaskReport;
 import org.apache.druid.indexer.report.TaskReportFileWriter;

+import java.io.ByteArrayOutputStream;
+import java.io.OutputStream;
+
 public class NoopTestTaskReportFileWriter implements TaskReportFileWriter
 {
   @Override
@@ -30,6 +33,13 @@ public void write(String id, TaskReport.ReportMap reports)
   {
   }

+  @Override
+  public OutputStream openReportOutputStream(String taskId)
+  {
+    // Stream to nowhere.
+    return new ByteArrayOutputStream();
+  }
+
   @Override
   public void setObjectMapper(ObjectMapper objectMapper)
   {
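Returning a fresh ByteArrayOutputStream is a simple way to discard report bytes in tests, at the cost of buffering them. On Java 11 and later, OutputStream.nullOutputStream() discards writes without buffering; a sketch, assuming a Java 11+ runtime:

import java.io.IOException;
import java.io.OutputStream;

final class NullSink
{
  static OutputStream openReportOutputStream(String taskId)
  {
    return OutputStream.nullOutputStream(); // discards all bytes written (Java 11+)
  }

  public static void main(String[] args) throws IOException
  {
    try (OutputStream out = openReportOutputStream("mytask")) {
      out.write(new byte[]{1, 2, 3}); // silently dropped
    }
  }
}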
diff --git a/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/msq/ITMultiStageQuery.java b/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/msq/ITMultiStageQuery.java
index 0fe486407db8..7ede24cd8f9a 100644
--- a/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/msq/ITMultiStageQuery.java
+++ b/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/msq/ITMultiStageQuery.java
@@ -26,7 +26,6 @@
 import org.apache.druid.indexer.report.TaskReport;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.StringUtils;
-import org.apache.druid.java.util.common.guava.Yielder;
 import org.apache.druid.msq.indexing.report.MSQResultsReport;
 import org.apache.druid.msq.indexing.report.MSQTaskReport;
 import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
@@ -253,13 +252,10 @@ public void testExport() throws Exception
         "Results report for the task id is empty"
     );

-    Yielder<Object[]> yielder = resultsReport.getResultYielder();
     List<List<Object>> actualResults = new ArrayList<>();

-    while (!yielder.isDone()) {
-      Object[] row = yielder.get();
+    for (final Object[] row : resultsReport.getResults()) {
       actualResults.add(Arrays.asList(row));
-      yielder = yielder.next(null);
     }

     ImmutableList<List<Object>> expectedResults = ImmutableList.of(
diff --git a/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java b/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java
index c5fc437fc9c7..2a2386869e4a 100644
--- a/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java
+++ b/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java
@@ -32,7 +32,6 @@
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.RetryUtils;
 import org.apache.druid.java.util.common.StringUtils;
-import org.apache.druid.java.util.common.guava.Yielder;
 import org.apache.druid.java.util.http.client.response.StatusResponseHolder;
 import org.apache.druid.msq.guice.MSQIndexingModule;
 import org.apache.druid.msq.indexing.report.MSQResultsReport;
@@ -215,17 +214,14 @@ private void compareResults(String taskId, MsqQueryWithResults expectedQueryWithResults)

     List<Map<String, Object>> actualResults = new ArrayList<>();

-    Yielder<Object[]> yielder = resultsReport.getResultYielder();
     List<MSQResultsReport.ColumnAndType> rowSignature = resultsReport.getSignature();

-    while (!yielder.isDone()) {
-      Object[] row = yielder.get();
+    for (final Object[] row : resultsReport.getResults()) {
       Map<String, Object> rowWithFieldNames = new LinkedHashMap<>();
       for (int i = 0; i < row.length; ++i) {
         rowWithFieldNames.put(rowSignature.get(i).getName(), row[i]);
       }
       actualResults.add(rowWithFieldNames);
-      yielder = yielder.next(null);
     }

     QueryResultVerifier.ResultVerificationObject resultsComparison = QueryResultVerifier.compareResults(
diff --git a/processing/src/main/java/org/apache/druid/frame/util/DurableStorageUtils.java b/processing/src/main/java/org/apache/druid/frame/util/DurableStorageUtils.java
index 3e5f2fe00a1f..168d96fc20ab 100644
--- a/processing/src/main/java/org/apache/druid/frame/util/DurableStorageUtils.java
+++ b/processing/src/main/java/org/apache/druid/frame/util/DurableStorageUtils.java
@@ -126,7 +126,11 @@ public static String getTaskIdOutputsFolderName(
 {
   return StringUtils.format(
       "%s/taskId_%s",
-      getWorkerOutputFolderName(controllerTaskId, stageNumber, workerNumber),
+      getWorkerOutputFolderName(
+          IdUtils.validateId("controller task ID", controllerTaskId),
+          stageNumber,
+          workerNumber
+      ),
       taskId
   );
 }
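Both integration-test helpers above now iterate resultsReport.getResults() directly and zip each row with the signature to build named rows. The same conversion in isolation, with plain strings standing in for the signature's ColumnAndType entries:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class RowZipper
{
  static List<Map<String, Object>> zip(List<String> columnNames, List<Object[]> rows)
  {
    final List<Map<String, Object>> out = new ArrayList<>();
    for (final Object[] row : rows) {
      final Map<String, Object> named = new LinkedHashMap<>();
      for (int i = 0; i < row.length; i++) {
        named.put(columnNames.get(i), row[i]);
      }
      out.add(named);
    }
    return out;
  }

  public static void main(String[] args)
  {
    System.out.println(zip(Arrays.asList("x", "y"), Arrays.<Object[]>asList(new Object[]{1, "a"})));
    // => [{x=1, y=a}]
  }
}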
diff --git a/processing/src/main/java/org/apache/druid/indexer/report/SingleFileTaskReportFileWriter.java b/processing/src/main/java/org/apache/druid/indexer/report/SingleFileTaskReportFileWriter.java
index d862b224d86e..9012f4e83a15 100644
--- a/processing/src/main/java/org/apache/druid/indexer/report/SingleFileTaskReportFileWriter.java
+++ b/processing/src/main/java/org/apache/druid/indexer/report/SingleFileTaskReportFileWriter.java
@@ -24,8 +24,9 @@
 import org.apache.druid.java.util.common.logger.Logger;

 import java.io.File;
-import java.io.FileOutputStream;
+import java.io.IOException;
 import java.io.OutputStream;
+import java.nio.file.Files;
@@ -42,21 +43,25 @@ public SingleFileTaskReportFileWriter(File reportsFile)
 @Override
 public void write(String taskId, TaskReport.ReportMap reports)
 {
-  try {
-    final File reportsFileParent = reportsFile.getParentFile();
-    if (reportsFileParent != null) {
-      FileUtils.mkdirp(reportsFileParent);
-    }
-
-    try (final FileOutputStream outputStream = new FileOutputStream(reportsFile)) {
-      writeReportToStream(objectMapper, outputStream, reports);
-    }
+  try (final OutputStream outputStream = openReportOutputStream(taskId)) {
+    writeReportToStream(objectMapper, outputStream, reports);
   }
   catch (Exception e) {
     log.error(e, "Encountered exception in write().");
   }
 }

+@Override
+public OutputStream openReportOutputStream(String taskId) throws IOException
+{
+  final File reportsFileParent = reportsFile.getParentFile();
+  if (reportsFileParent != null) {
+    FileUtils.mkdirp(reportsFileParent);
+  }
+
+  return Files.newOutputStream(reportsFile.toPath());
+}
+
 @Override
 public void setObjectMapper(ObjectMapper objectMapper)
 {
diff --git a/processing/src/main/java/org/apache/druid/indexer/report/TaskReportFileWriter.java b/processing/src/main/java/org/apache/druid/indexer/report/TaskReportFileWriter.java
index bb3ebcd0394a..0cdd02493662 100644
--- a/processing/src/main/java/org/apache/druid/indexer/report/TaskReportFileWriter.java
+++ b/processing/src/main/java/org/apache/druid/indexer/report/TaskReportFileWriter.java
@@ -21,9 +21,14 @@

 import com.fasterxml.jackson.databind.ObjectMapper;

+import java.io.IOException;
+import java.io.OutputStream;
+
 public interface TaskReportFileWriter
 {
   void write(String taskId, TaskReport.ReportMap reports);

+  OutputStream openReportOutputStream(String taskId) throws IOException;
+
   void setObjectMapper(ObjectMapper objectMapper);
 }
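The new openReportOutputStream method turns the writers into streamable sinks: write() becomes a thin wrapper, and callers producing large reports can stream bytes instead of materializing them first. A sketch of that split, with a hypothetical ReportSink in place of TaskReportFileWriter:

import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;

interface ReportSink
{
  OutputStream openReportOutputStream(String taskId) throws IOException;

  default void write(String taskId, byte[] serializedReport)
  {
    try (OutputStream out = openReportOutputStream(taskId)) {
      out.write(serializedReport);
    }
    catch (IOException e) {
      // Matches the diff's choice of logging rather than propagating from write().
      System.err.println("Encountered exception in write(): " + e);
    }
  }
}

class FileReportSink implements ReportSink
{
  private final Path file;

  FileReportSink(Path file)
  {
    this.file = file;
  }

  @Override
  public OutputStream openReportOutputStream(String taskId) throws IOException
  {
    Files.createDirectories(file.toAbsolutePath().getParent()); // like FileUtils.mkdirp
    return Files.newOutputStream(file);
  }
}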
diff --git a/server/src/main/java/org/apache/druid/rpc/RequestBuilder.java b/server/src/main/java/org/apache/druid/rpc/RequestBuilder.java
index 224cfc78ed11..d6dde12fd6db 100644
--- a/server/src/main/java/org/apache/druid/rpc/RequestBuilder.java
+++ b/server/src/main/java/org/apache/druid/rpc/RequestBuilder.java
@@ -32,8 +32,6 @@
 import org.joda.time.Duration;

 import javax.ws.rs.core.MediaType;
-import java.net.MalformedURLException;
-import java.net.URL;
 import java.util.Arrays;
 import java.util.Map;
 import java.util.Objects;
@@ -77,11 +75,11 @@ public RequestBuilder content(final String contentType, final byte[] content)
     return this;
   }

-  public RequestBuilder jsonContent(final ObjectMapper jsonMapper, final Object content)
+  public RequestBuilder objectContent(final ObjectMapper objectMapper, final String contentType, final Object content)
   {
     try {
-      this.contentType = MediaType.APPLICATION_JSON;
-      this.content = jsonMapper.writeValueAsBytes(Preconditions.checkNotNull(content, "content"));
+      this.contentType = contentType;
+      this.content = objectMapper.writeValueAsBytes(Preconditions.checkNotNull(content, "content"));
       return this;
     }
     catch (JsonProcessingException e) {
@@ -89,16 +87,14 @@ public RequestBuilder jsonContent(final ObjectMapper jsonMapper, final Object content)
     }
   }

+  public RequestBuilder jsonContent(final ObjectMapper jsonMapper, final Object content)
+  {
+    return objectContent(jsonMapper, MediaType.APPLICATION_JSON, content);
+  }
+
   public RequestBuilder smileContent(final ObjectMapper smileMapper, final Object content)
   {
-    try {
-      this.contentType = SmileMediaTypes.APPLICATION_JACKSON_SMILE;
-      this.content = smileMapper.writeValueAsBytes(Preconditions.checkNotNull(content, "content"));
-      return this;
-    }
-    catch (JsonProcessingException e) {
-      throw new RuntimeException(e);
-    }
+    return objectContent(smileMapper, SmileMediaTypes.APPLICATION_JACKSON_SMILE, content);
   }

   public RequestBuilder timeout(final Duration timeout)
@@ -121,8 +117,7 @@ public Duration getTimeout()
   public Request build(ServiceLocation serviceLocation)
   {
     // It's expected that our encodedPathAndQueryString starts with '/' and the service base path doesn't end with one.
-    final String path = serviceLocation.getBasePath() + encodedPathAndQueryString;
-    final Request request = new Request(method, makeURL(serviceLocation, path));
+    final Request request = new Request(method, serviceLocation.toURL(encodedPathAndQueryString));

     for (final Map.Entry<String, String> entry : headers.entries()) {
       request.addHeader(entry.getKey(), entry.getValue());
@@ -135,29 +130,6 @@ public Request build(ServiceLocation serviceLocation)
     return request;
   }

-  private URL makeURL(final ServiceLocation serviceLocation, final String encodedPathAndQueryString)
-  {
-    final String scheme;
-    final int portToUse;
-
-    if (serviceLocation.getTlsPort() > 0) {
-      // Prefer HTTPS if available.
-      scheme = "https";
-      portToUse = serviceLocation.getTlsPort();
-    } else {
-      scheme = "http";
-      portToUse = serviceLocation.getPlaintextPort();
-    }
-
-    // Use URL constructor, not URI, since the path is already encoded.
-    try {
-      return new URL(scheme, serviceLocation.getHost(), portToUse, encodedPathAndQueryString);
-    }
-    catch (MalformedURLException e) {
-      throw new IllegalArgumentException(e);
-    }
-  }
-
   @Override
   public boolean equals(Object o)
   {
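ServiceLocation.toURL (below) now owns the prefer-HTTPS URL construction that RequestBuilder.makeURL used to do. The core logic in isolation, as a runnable sketch:

import java.net.MalformedURLException;
import java.net.URL;

final class UrlBuilder
{
  static URL toURL(String host, int plaintextPort, int tlsPort, String basePath, String encodedPathAndQueryString)
  {
    // Prefer HTTPS if a TLS port is available.
    final String scheme = tlsPort > 0 ? "https" : "http";
    final int port = tlsPort > 0 ? tlsPort : plaintextPort;
    try {
      // URL constructor, not URI, since the path is already encoded.
      return new URL(scheme, host, port, basePath + encodedPathAndQueryString);
    }
    catch (MalformedURLException e) {
      throw new IllegalArgumentException(e);
    }
  }

  public static void main(String[] args)
  {
    System.out.println(toURL("localhost", 8100, -1, "", "/status"));
    // => http://localhost:8100/status
  }
}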
diff --git a/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java b/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java
index 3a092d7cb8dd..aeaa24318e93 100644
--- a/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java
+++ b/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java
@@ -26,7 +26,10 @@
 import org.apache.druid.server.DruidNode;
 import org.apache.druid.server.coordination.DruidServerMetadata;

+import javax.annotation.Nullable;
 import javax.validation.constraints.NotNull;
+import java.net.MalformedURLException;
+import java.net.URL;
 import java.util.Iterator;
 import java.util.Objects;

@@ -35,11 +38,24 @@
  */
 public class ServiceLocation
 {
+  private static final String HTTP_SCHEME = "http";
+  private static final String HTTPS_SCHEME = "https";
+  private static final Splitter HOST_SPLITTER = Splitter.on(":").limit(2);
+
   private final String host;
   private final int plaintextPort;
   private final int tlsPort;
   private final String basePath;

+  /**
+   * Create a service location.
+   *
+   * @param host          hostname or address
+   * @param plaintextPort plaintext port
+   * @param tlsPort       TLS port
+   * @param basePath      base path; must be encoded and must not include trailing "/". In particular, to use root as
+   *                      the base path, pass "" for this parameter.
+   */
   public ServiceLocation(final String host, final int plaintextPort, final int tlsPort, final String basePath)
   {
     this.host = Preconditions.checkNotNull(host, "host");
@@ -48,13 +64,19 @@ public ServiceLocation(final String host, final int plaintextPort, final int tlsPort, final String basePath)
     this.basePath = Preconditions.checkNotNull(basePath, "basePath");
   }

+  /**
+   * Create a service location based on a {@link DruidNode}, without a base path.
+   */
   public static ServiceLocation fromDruidNode(final DruidNode druidNode)
   {
     return new ServiceLocation(druidNode.getHost(), druidNode.getPlaintextPort(), druidNode.getTlsPort(), "");
   }

-  private static final Splitter SPLITTER = Splitter.on(":").limit(2);
-
+  /**
+   * Create a service location based on a {@link DruidServerMetadata}.
+   *
+   * @throws IllegalArgumentException if the server metadata cannot be mapped to a service location.
+   */
   public static ServiceLocation fromDruidServerMetadata(final DruidServerMetadata druidServerMetadata)
   {
     final String host = getHostFromString(
@@ -71,7 +93,7 @@ public static ServiceLocation fromDruidServerMetadata(final DruidServerMetadata druidServerMetadata)

   private static String getHostFromString(@NotNull String s)
   {
-    Iterator<String> iterator = SPLITTER.split(s).iterator();
+    Iterator<String> iterator = HOST_SPLITTER.split(s).iterator();
     ImmutableList<String> strings = ImmutableList.copyOf(iterator);
     return strings.get(0);
   }
@@ -81,7 +103,7 @@ private static int getPortFromString(String s)
   {
     if (s == null) {
       return -1;
     }
-    Iterator<String> iterator = SPLITTER.split(s).iterator();
+    Iterator<String> iterator = HOST_SPLITTER.split(s).iterator();
     ImmutableList<String> strings = ImmutableList.copyOf(iterator);
     try {
       return Integer.parseInt(strings.get(1));
@@ -111,6 +133,33 @@ public String getBasePath()
     return basePath;
   }

+  public URL toURL(@Nullable final String encodedPathAndQueryString)
+  {
+    final String scheme;
+    final int portToUse;
+
+    if (tlsPort > 0) {
+      // Prefer HTTPS if available.
+      scheme = HTTPS_SCHEME;
+      portToUse = tlsPort;
+    } else {
+      scheme = HTTP_SCHEME;
+      portToUse = plaintextPort;
+    }
+
+    try {
+      return new URL(
+          scheme,
+          host,
+          portToUse,
+          basePath + (encodedPathAndQueryString == null ? "" : encodedPathAndQueryString)
+      );
+    }
+    catch (MalformedURLException e) {
+      throw new IllegalArgumentException(e);
+    }
+  }
+
   @Override
   public boolean equals(Object o)
   {
@@ -143,4 +192,5 @@ public String toString()
            ", basePath='" + basePath + '\'' +
            '}';
   }
+
 }
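For reference, the HOST_SPLITTER parsing above splits on the first ':' only, so everything after that colon is treated as the port field. A runnable sketch of the getHostFromString/getPortFromString behavior:

import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;

final class HostAndPortParser
{
  private static final Splitter HOST_SPLITTER = Splitter.on(":").limit(2);

  static String host(String s)
  {
    return ImmutableList.copyOf(HOST_SPLITTER.split(s)).get(0);
  }

  static int port(String s)
  {
    final ImmutableList<String> parts = ImmutableList.copyOf(HOST_SPLITTER.split(s));
    if (parts.size() < 2) {
      return -1; // no port present
    }
    try {
      return Integer.parseInt(parts.get(1));
    }
    catch (NumberFormatException e) {
      return -1;
    }
  }

  public static void main(String[] args)
  {
    System.out.println(host("indexer1.example.com:8091")); // indexer1.example.com
    System.out.println(port("indexer1.example.com:8091")); // 8091
  }
}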