From 2a41b8349c4fa5792eed1a0f01ebacc38aaa3cfc Mon Sep 17 00:00:00 2001 From: jaystarshot Date: Mon, 2 Dec 2024 15:34:58 -0800 Subject: [PATCH] Add enhanced cte scheduling mode --- .../main/sphinx/admin/cte-materialization.rst | 12 +++- .../presto/hive/TestCteExecution.java | 9 +++ .../presto/SystemSessionProperties.java | 11 ++++ .../presto/execution/SqlStageExecution.java | 61 +++++++++++++++++- .../scheduler/CTEMaterializationTracker.java | 63 +++++++++++++++++++ .../FixedSourcePartitionedScheduler.java | 32 +++++++++- .../execution/scheduler/ScheduleResult.java | 6 ++ .../scheduler/SectionExecutionFactory.java | 24 ++++--- .../scheduler/SplitSchedulerStats.java | 9 +++ .../scheduler/SqlQueryScheduler.java | 38 ++++++++--- .../presto/sql/analyzer/FeaturesConfig.java | 13 ++++ .../sql/analyzer/TestFeaturesConfig.java | 7 ++- .../planner/TestCanonicalPlanGenerator.java | 2 +- 13 files changed, 263 insertions(+), 24 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java diff --git a/presto-docs/src/main/sphinx/admin/cte-materialization.rst b/presto-docs/src/main/sphinx/admin/cte-materialization.rst index 332bc02cb289..9b169f02cee2 100644 --- a/presto-docs/src/main/sphinx/admin/cte-materialization.rst +++ b/presto-docs/src/main/sphinx/admin/cte-materialization.rst @@ -118,7 +118,6 @@ This setting specifies the Hash function type for CTE materialization. Use the ``hive.bucket_function_type_for_cte_materialization`` session property to set on a per-query basis. - ``query.max-written-intermediate-bytes`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -129,6 +128,17 @@ This setting defines a cap on the amount of data that can be written during CTE Use the ``query_max_written_intermediate_bytes`` session property to set on a per-query basis. +``enhanced-cte-scheduling-enabled`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``boolean`` + * **Default value:** ``true`` + +Flag to enable or disable the enhanced-cte-blocking during CTE Materialization. Enhanced CTE blocking restricts only the table scan stages of the CTE TableScan, rather than blocking entire plan sections, including the main query, until the query completes. +This approach can improve latency in scenarios where parts of the query can execute concurrently with CTE materialization writes. + +Use the ``enhanced_cte_scheduling_enabled`` session property to set on a per-query basis. + How to Participate in Development --------------------------------- diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java index c86b91f206f1..1e4f13b96390 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java @@ -1171,6 +1171,15 @@ public void testChainedCteProjectionAndFilterPushDown() generateMaterializedCTEInformation("cte5", 1, false, true))); } + @Test + public void testCTEMaterializationWithEnhancedScheduling() + { + QueryRunner queryRunner = getQueryRunner(); + String sql = "WITH temp as (SELECT orderkey FROM ORDERS) " + + "SELECT * FROM temp t1 JOIN (SELECT custkey FROM customer) c ON t1.orderkey=c.custkey"; + verifyResults(queryRunner, sql, ImmutableList.of(generateMaterializedCTEInformation("temp", 1, false, true))); + } + @Test public void testWrittenIntemediateByteLimit() throws Exception diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 7f23e8172e80..93a485574f1a 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -234,6 +234,7 @@ public final class SystemSessionProperties public static final String QUERY_RETRY_MAX_EXECUTION_TIME = "query_retry_max_execution_time"; public static final String PARTIAL_RESULTS_ENABLED = "partial_results_enabled"; public static final String PARTIAL_RESULTS_COMPLETION_RATIO_THRESHOLD = "partial_results_completion_ratio_threshold"; + public static final String ENHANCED_CTE_SCHEDULING_ENABLED = "enhanced-cte-scheduling-enabled"; public static final String PARTIAL_RESULTS_MAX_EXECUTION_TIME_MULTIPLIER = "partial_results_max_execution_time_multiplier"; public static final String OFFSET_CLAUSE_ENABLED = "offset_clause_enabled"; public static final String VERBOSE_EXCEEDED_MEMORY_LIMIT_ERRORS_ENABLED = "verbose_exceeded_memory_limit_errors_enabled"; @@ -1282,6 +1283,11 @@ public SystemSessionProperties( "Minimum query completion ratio threshold for partial results", featuresConfig.getPartialResultsCompletionRatioThreshold(), false), + booleanProperty( + ENHANCED_CTE_SCHEDULING_ENABLED, + "Applicable for CTE Materialization. If enabled, only tablescans of the pending tablewriters are blocked and other stages can continue.", + featuresConfig.getEnhancedCTESchedulingEnabled(), + true), booleanProperty( OFFSET_CLAUSE_ENABLED, "Enable support for OFFSET clause", @@ -2690,6 +2696,11 @@ public static double getPartialResultsCompletionRatioThreshold(Session session) return session.getSystemProperty(PARTIAL_RESULTS_COMPLETION_RATIO_THRESHOLD, Double.class); } + public static boolean isEnhancedCTESchedulingEnabled(Session session) + { + return isCteMaterializationApplicable(session) & session.getSystemProperty(ENHANCED_CTE_SCHEDULING_ENABLED, Boolean.class); + } + public static double getPartialResultsMaxExecutionTimeMultiplier(Session session) { return session.getSystemProperty(PARTIAL_RESULTS_MAX_EXECUTION_TIME_MULTIPLIER, Double.class); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java index ee2159eba56c..2c6b7e8b60de 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java @@ -26,10 +26,15 @@ import com.facebook.presto.metadata.Split; import com.facebook.presto.server.remotetask.HttpRemoteTask; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.plan.CteMaterializationInfo; import com.facebook.presto.spi.plan.PlanFragmentId; +import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.plan.TableFinishNode; +import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.split.RemoteSplit; import com.facebook.presto.sql.planner.PlanFragment; +import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableMap; @@ -60,8 +65,10 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import java.util.stream.Collectors; import static com.facebook.presto.SystemSessionProperties.getMaxFailedTaskPercentage; +import static com.facebook.presto.SystemSessionProperties.isEnhancedCTESchedulingEnabled; import static com.facebook.presto.failureDetector.FailureDetector.State.GONE; import static com.facebook.presto.operator.ExchangeOperator.REMOTE_CONNECTOR_ID; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; @@ -557,7 +564,6 @@ private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, M // stage finished while we were scheduling this task task.abort(); } - return task; } @@ -594,6 +600,59 @@ private static Split createRemoteSplitFor(TaskId taskId, URI remoteSourceTaskLoc return new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), new RemoteSplit(new Location(splitLocation), remoteSourceTaskId)); } + private static String getCteIdFromSource(PlanNode source) + { + // Traverse the plan node tree to find a TableWriterNode with TemporaryTableInfo + return PlanNodeSearcher.searchFrom(source) + .where(planNode -> planNode instanceof TableFinishNode) + .findFirst() + .flatMap(planNode -> ((TableFinishNode) planNode).getCteMaterializationInfo()) + .map(CteMaterializationInfo::getCteId) + .orElseThrow(() -> new IllegalStateException("TemporaryTableInfo has no CTE ID")); + } + + public boolean isCTETableFinishStage() + { + return PlanNodeSearcher.searchFrom(planFragment.getRoot()) + .where(planNode -> planNode instanceof TableFinishNode && + ((TableFinishNode) planNode).getCteMaterializationInfo().isPresent()) + .findSingle() + .isPresent(); + } + + public String getCTEWriterId() + { + // Validate that this is a CTE TableFinish stage and return the associated CTE ID + if (!isCTETableFinishStage()) { + throw new IllegalStateException("This stage is not a CTE writer stage"); + } + return getCteIdFromSource(planFragment.getRoot()); + } + + public boolean requiresMaterializedCTE() + { + if (!isEnhancedCTESchedulingEnabled(session)) { + return false; + } + // Search for TableScanNodes and check if they reference TemporaryTableInfo + return PlanNodeSearcher.searchFrom(planFragment.getRoot()) + .where(planNode -> planNode instanceof TableScanNode) + .findAll().stream() + .anyMatch(planNode -> ((TableScanNode) planNode).getCteMaterializationInfo().isPresent()); + } + + public List getRequiredCTEList() + { + // Collect all CTE IDs referenced by TableScanNodes with TemporaryTableInfo + return PlanNodeSearcher.searchFrom(planFragment.getRoot()) + .where(planNode -> planNode instanceof TableScanNode) + .findAll().stream() + .map(planNode -> ((TableScanNode) planNode).getCteMaterializationInfo() + .orElseThrow(() -> new IllegalStateException("TableScanNode has no TemporaryTableInfo"))) + .map(CteMaterializationInfo::getCteId) + .collect(Collectors.toList()); + } + private void updateTaskStatus(TaskId taskId, TaskStatus taskStatus) { StageExecutionState stageExecutionState = getState(); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java new file mode 100644 index 000000000000..a82ef6508d95 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.scheduler; + +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/* + * Tracks the completion status of table-finish nodes that write temporary tables for CTE materialization. + * CTEMaterializationTracker manages a map of materialized CTEs and their associated materialization futures. + * When a stage includes a CTE table finish, it marks the corresponding CTE as materialized and completes + * the associated future. + * This signals the scheduler that some dependency has been resolved, prompting it to resume/continue scheduling. + */ +public class CTEMaterializationTracker +{ + private final Map> materializationFutures = new ConcurrentHashMap<>(); + + public ListenableFuture getFutureForCTE(String cteName) + { + return Futures.nonCancellationPropagating( + materializationFutures.compute(cteName, (key, existingFuture) -> { + if (existingFuture == null) { + // Create a new SettableFuture and store it internally + return SettableFuture.create(); + } + Preconditions.checkArgument(!existingFuture.isCancelled(), + String.format("Error: Existing future was found cancelled in CTEMaterializationTracker for cte", cteName)); + return existingFuture; + })); + } + + public void markCTEAsMaterialized(String cteName) + { + materializationFutures.compute(cteName, (key, existingFuture) -> { + if (existingFuture == null) { + SettableFuture completedFuture = SettableFuture.create(); + completedFuture.set(null); + return completedFuture; + } + Preconditions.checkArgument(!existingFuture.isCancelled(), + String.format("Error: Existing future was found cancelled in CTEMaterializationTracker for cte", cteName)); + existingFuture.set(null); // Notify all listeners + return existingFuture; + }); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java index 8e965aab792c..9febc345114f 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java @@ -74,6 +74,8 @@ public class FixedSourcePartitionedScheduler private final Queue tasksToRecover = new ConcurrentLinkedQueue<>(); + private final CTEMaterializationTracker cteMaterializationTracker; + @GuardedBy("this") private boolean closed; @@ -87,13 +89,15 @@ public FixedSourcePartitionedScheduler( int splitBatchSize, OptionalInt concurrentLifespansPerTask, NodeSelector nodeSelector, - List partitionHandles) + List partitionHandles, + CTEMaterializationTracker cteMaterializationTracker) { requireNonNull(stage, "stage is null"); requireNonNull(splitSources, "splitSources is null"); requireNonNull(bucketNodeMap, "bucketNodeMap is null"); checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty"); requireNonNull(partitionHandles, "partitionHandles is null"); + this.cteMaterializationTracker = cteMaterializationTracker; this.stage = stage; this.nodes = ImmutableList.copyOf(nodes); @@ -179,6 +183,29 @@ public ScheduleResult schedule() { // schedule a task on every node in the distribution List newTasks = ImmutableList.of(); + + // CTE Materialization Check + if (stage.requiresMaterializedCTE()) { + List> blocked = new ArrayList<>(); + List requiredCTEIds = stage.getRequiredCTEList(); + for (String cteId : requiredCTEIds) { + ListenableFuture cteFuture = cteMaterializationTracker.getFutureForCTE(cteId); + if (!cteFuture.isDone()) { + // Add CTE materialization future to the blocked list + blocked.add(cteFuture); + } + } + // If any CTE is not materialized, return a blocked ScheduleResult + if (!blocked.isEmpty()) { + return ScheduleResult.blocked( + false, + newTasks, + whenAnyComplete(blocked), + BlockedReason.WAITING_FOR_CTE_MATERIALIZATION, + 0); + } + } + // schedule a task on every node in the distribution if (!scheduledTasks) { newTasks = Streams.mapWithIndex( nodes.stream(), @@ -191,9 +218,8 @@ public ScheduleResult schedule() // notify listeners that we have scheduled all tasks so they can set no more buffers or exchange splits stage.transitionToFinishedTaskScheduling(); } - - boolean allBlocked = true; List> blocked = new ArrayList<>(); + boolean allBlocked = true; BlockedReason blockedReason = BlockedReason.NO_ACTIVE_DRIVER_GROUP; if (groupedLifespanScheduler.isPresent()) { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java index ed85bfff8fd9..dfc6288f7861 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java @@ -57,6 +57,11 @@ public enum BlockedReason * grouped execution where there are multiple lifespans per task). */ MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE, + + /** + * Waiting for the completion of CTE materialization by the table writer. + */ + WAITING_FOR_CTE_MATERIALIZATION, /**/; public BlockedReason combineWith(BlockedReason other) @@ -64,6 +69,7 @@ public BlockedReason combineWith(BlockedReason other) switch (this) { case WRITER_SCALING: throw new IllegalArgumentException("cannot be combined"); + case WAITING_FOR_CTE_MATERIALIZATION: case NO_ACTIVE_DRIVER_GROUP: return other; case SPLIT_QUEUES_FULL: diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java index cba5736650fd..26b77ef163ca 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java @@ -170,7 +170,8 @@ public SectionExecution createSectionExecutions( boolean summarizeTaskInfo, RemoteTaskFactory remoteTaskFactory, SplitSourceFactory splitSourceFactory, - int attemptId) + int attemptId, + CTEMaterializationTracker cteMaterializationTracker) { // Only fetch a distribution once per section to ensure all stages see the same machine assignments Map partitioningCache = new HashMap<>(); @@ -186,7 +187,8 @@ public SectionExecution createSectionExecutions( summarizeTaskInfo, remoteTaskFactory, splitSourceFactory, - attemptId); + attemptId, + cteMaterializationTracker); StageExecutionAndScheduler rootStage = getLast(sectionStages); rootStage.getStageExecution().setOutputBuffers(outputBuffers); return new SectionExecution(rootStage, sectionStages); @@ -205,7 +207,8 @@ private List createStreamingLinkedStageExecutions( boolean summarizeTaskInfo, RemoteTaskFactory remoteTaskFactory, SplitSourceFactory splitSourceFactory, - int attemptId) + int attemptId, + CTEMaterializationTracker cteMaterializationTracker) { ImmutableList.Builder stageExecutionAndSchedulers = ImmutableList.builder(); @@ -240,7 +243,8 @@ private List createStreamingLinkedStageExecutions( summarizeTaskInfo, remoteTaskFactory, splitSourceFactory, - attemptId); + attemptId, + cteMaterializationTracker); stageExecutionAndSchedulers.addAll(subTree); childStagesBuilder.add(getLast(subTree).getStageExecution()); } @@ -262,7 +266,8 @@ private List createStreamingLinkedStageExecutions( stageExecution, partitioningHandle, tableWriteInfo, - childStageExecutions); + childStageExecutions, + cteMaterializationTracker); stageExecutionAndSchedulers.add(new StageExecutionAndScheduler( stageExecution, stageLinkage, @@ -281,7 +286,8 @@ private StageScheduler createStageScheduler( SqlStageExecution stageExecution, PartitioningHandle partitioningHandle, TableWriteInfo tableWriteInfo, - Set childStageExecutions) + Set childStageExecutions, + CTEMaterializationTracker cteMaterializationTracker) { Map splitSources = splitSourceFactory.createSplitSources(plan.getFragment(), session, tableWriteInfo); int maxTasksPerStage = getMaxTasksPerStage(session); @@ -341,7 +347,8 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { splitBatchSize, getConcurrentLifespansPerNode(session), nodeSelector, - ImmutableList.of(NOT_PARTITIONED)); + ImmutableList.of(NOT_PARTITIONED), + cteMaterializationTracker); } else if (!splitSources.isEmpty()) { // contains local source @@ -400,7 +407,8 @@ else if (!splitSources.isEmpty()) { splitBatchSize, getConcurrentLifespansPerNode(session), nodeScheduler.createNodeSelector(session, connectorId, nodePredicate), - connectorPartitionHandles); + connectorPartitionHandles, + cteMaterializationTracker); if (plan.getFragment().getStageExecutionDescriptor().isRecoverableGroupedExecution()) { stageExecution.registerStageTaskRecoveryCallback(taskId -> { checkArgument(taskId.getStageExecutionId().getStageId().equals(stageId), "The task did not execute this stage"); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java index 6d0507082e8b..83b4bbaa8c69 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java @@ -32,6 +32,8 @@ public class SplitSchedulerStats private final CounterStat splitQueuesFull = new CounterStat(); private final CounterStat mixedSplitQueuesFullAndWaitingForSource = new CounterStat(); private final CounterStat noActiveDriverGroup = new CounterStat(); + + private final CounterStat waitingForCTEMaterialization = new CounterStat(); private final DistributionStat splitsPerIteration = new DistributionStat(); @Managed @@ -62,6 +64,13 @@ public CounterStat getWaitingForSource() return waitingForSource; } + @Managed + @Nested + public CounterStat getWaitingForCTEMaterialization() + { + return waitingForCTEMaterialization; + } + @Managed @Nested public CounterStat getSplitQueuesFull() diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java index 198f23f1ba22..a255a939f6ff 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java @@ -79,6 +79,7 @@ import static com.facebook.presto.SystemSessionProperties.getMaxConcurrentMaterializations; import static com.facebook.presto.SystemSessionProperties.getPartialResultsCompletionRatioThreshold; import static com.facebook.presto.SystemSessionProperties.getPartialResultsMaxExecutionTimeMultiplier; +import static com.facebook.presto.SystemSessionProperties.isEnhancedCTESchedulingEnabled; import static com.facebook.presto.SystemSessionProperties.isPartialResultsEnabled; import static com.facebook.presto.SystemSessionProperties.isRuntimeOptimizerEnabled; import static com.facebook.presto.execution.BasicStageExecutionStats.aggregateBasicStageStats; @@ -149,6 +150,7 @@ public class SqlQueryScheduler private final AtomicBoolean scheduling = new AtomicBoolean(); private final PartialResultQueryTaskTracker partialResultQueryTaskTracker; + private final CTEMaterializationTracker cteMaterializationTracker = new CTEMaterializationTracker(); public static SqlQueryScheduler createSqlQueryScheduler( LocationFactory locationFactory, @@ -278,6 +280,17 @@ else if (state == CANCELED) { for (StageExecutionAndScheduler stageExecutionInfo : stageExecutions.values()) { SqlStageExecution stageExecution = stageExecutionInfo.getStageExecution(); + // Add a listener for state changes + if (stageExecution.isCTETableFinishStage()) { + stageExecution.addStateChangeListener(state -> { + if (state == StageExecutionState.FINISHED) { + String cteName = stageExecution.getCTEWriterId(); + log.debug("CTE write completed for: " + cteName); + // Notify the materialization tracker + cteMaterializationTracker.markCTEAsMaterialized(cteName); + } + }); + } stageExecution.addStateChangeListener(state -> { if (queryStateMachine.isDone()) { return; @@ -363,7 +376,8 @@ private List createStageExecutions( summarizeTaskInfo, remoteTaskFactory, splitSourceFactory, - 0).getSectionStages(); + 0, + cteMaterializationTracker).getSectionStages(); stages.addAll(sectionStages); return stages.build(); @@ -460,7 +474,9 @@ else if (!result.getBlocked().isDone()) { ScheduleResult.BlockedReason blockedReason = result.getBlockedReason().get(); switch (blockedReason) { case WRITER_SCALING: - // no-op + break; + case WAITING_FOR_CTE_MATERIALIZATION: + schedulerStats.getWaitingForCTEMaterialization().update(1); break; case WAITING_FOR_SOURCE: schedulerStats.getWaitingForSource().update(1); @@ -568,10 +584,12 @@ private List getSectionsReadyForExecution() .map(section -> getStageExecution(section.getPlan().getFragment().getId()).getState()) .filter(state -> !state.isDone() && state != PLANNED) .count(); + return stream(forTree(StreamingPlanSection::getChildren).depthFirstPreOrder(sectionedPlan)) // get all sections ready for execution .filter(this::isReadyForExecution) - .limit(maxConcurrentMaterializations - runningPlanSections) + // for enhanced cte blocking we do not need a limit on the sections + .limit(isEnhancedCTESchedulingEnabled(session) ? Long.MAX_VALUE : maxConcurrentMaterializations - runningPlanSections) .map(this::tryCostBasedOptimize) .collect(toImmutableList()); } @@ -678,7 +696,8 @@ private void updateStageExecutions(StreamingPlanSection section, Map updatedStageExecutions = sectionExecution.getSectionStages().stream() .collect(toImmutableMap(execution -> execution.getStageExecution().getStageExecutionId().getStageId(), identity())); @@ -774,10 +793,13 @@ private boolean isReadyForExecution(StreamingPlanSection section) // already scheduled return false; } - for (StreamingPlanSection child : section.getChildren()) { - SqlStageExecution rootStageExecution = getStageExecution(child.getPlan().getFragment().getId()); - if (rootStageExecution.getState() != FINISHED) { - return false; + if (!isEnhancedCTESchedulingEnabled(session)) { + // Enhanced cte blocking is not enabled so block till child sections are complete + for (StreamingPlanSection child : section.getChildren()) { + SqlStageExecution rootStageExecution = getStageExecution(child.getPlan().getFragment().getId()); + if (rootStageExecution.getState() != FINISHED) { + return false; + } } } return true; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index d6c329d195d9..4ee4907dc1cb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -138,6 +138,7 @@ public class FeaturesConfig private boolean ignoreStatsCalculatorFailures = true; private boolean printStatsForNonJoinQuery; private boolean defaultFilterFactorEnabled; + private boolean enhancedCteSchedulingEnabled = true; // Give a default 10% selectivity coefficient factor to avoid hitting unknown stats in join stats estimates // which could result in syntactic join order. Set it to 0 to disable this feature private double defaultJoinSelectivityCoefficient; @@ -1290,6 +1291,18 @@ public boolean isDefaultFilterFactorEnabled() return defaultFilterFactorEnabled; } + @Config("enhanced-cte-scheduling-enabled") + public FeaturesConfig setEnhancedCTESchedulingEnabled(boolean enhancedCTEBlockingEnabled) + { + this.enhancedCteSchedulingEnabled = enhancedCTEBlockingEnabled; + return this; + } + + public boolean getEnhancedCTESchedulingEnabled() + { + return enhancedCteSchedulingEnabled; + } + @Config("optimizer.default-join-selectivity-coefficient") @ConfigDescription("Used when join selectivity estimation is unknown. Default 0 to disable the use of join selectivity, this will allow planner to fall back to FROM-clause join order when the join cardinality is unknown") public FeaturesConfig setDefaultJoinSelectivityCoefficient(double defaultJoinSelectivityCoefficient) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 120a48ca62fd..55f0e2e5e254 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -250,7 +250,8 @@ public void testDefaults() .setEagerPlanValidationThreadPoolSize(20) .setPrestoSparkExecutionEnvironment(false) .setSingleNodeExecutionEnabled(false) - .setNativeExecutionScaleWritersThreadsEnabled(false)); + .setNativeExecutionScaleWritersThreadsEnabled(false) + .setEnhancedCTESchedulingEnabled(true)); } @Test @@ -450,6 +451,7 @@ public void testExplicitPropertyMappings() .put("presto-spark-execution-environment", "true") .put("single-node-execution-enabled", "true") .put("native-execution-scale-writer-threads-enabled", "true") + .put("enhanced-cte-scheduling-enabled", "false") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -646,7 +648,8 @@ public void testExplicitPropertyMappings() .setEagerPlanValidationThreadPoolSize(2) .setPrestoSparkExecutionEnvironment(true) .setSingleNodeExecutionEnabled(true) - .setNativeExecutionScaleWritersThreadsEnabled(true); + .setNativeExecutionScaleWritersThreadsEnabled(true) + .setEnhancedCTESchedulingEnabled(false); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalPlanGenerator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalPlanGenerator.java index e38e5c4eae0c..d64e46242904 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalPlanGenerator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalPlanGenerator.java @@ -375,7 +375,7 @@ public void testCanonicalTableScanNodeField() .filter(f -> !f.isSynthetic()) .map(Field::getName) .collect(toImmutableSet()), - ImmutableSet.of("table", "assignments", "outputVariables", "currentConstraint", "enforcedConstraint", "tableConstraints")); + ImmutableSet.of("table", "assignments", "outputVariables", "currentConstraint", "enforcedConstraint", "tableConstraints", "cteMaterializationInfo")); assertEquals( Arrays.stream(CanonicalTableScanNode.class.getDeclaredFields()) .filter(f -> !f.isSynthetic())