Skip to content

Commit

Permalink
Add enhanced cte scheduling mode
Browse files Browse the repository at this point in the history
  • Loading branch information
jaystarshot committed Jan 10, 2025
1 parent 17f365d commit 4e1a390
Show file tree
Hide file tree
Showing 13 changed files with 263 additions and 24 deletions.
12 changes: 11 additions & 1 deletion presto-docs/src/main/sphinx/admin/cte-materialization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ This setting specifies the Hash function type for CTE materialization.

Use the ``hive.bucket_function_type_for_cte_materialization`` session property to set on a per-query basis.


``query.max-written-intermediate-bytes``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand All @@ -129,6 +128,17 @@ This setting defines a cap on the amount of data that can be written during CTE

Use the ``query_max_written_intermediate_bytes`` session property to set on a per-query basis.

``enhanced-cte-scheduling-enabled``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* **Type:** ``boolean``
* **Default value:** ``true``

Flag to enable or disable enhanced CTE blocking during CTE materialization. With enhanced CTE blocking, only the stages that scan a materialized CTE are blocked until that CTE has been written, instead of blocking entire plan sections, including the main query.
This approach can improve latency in scenarios where parts of the query can execute concurrently with CTE materialization writes.

Use the ``enhanced_cte_scheduling_enabled`` session property to set on a per-query basis.


How to Participate in Development
---------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,15 @@ public void testChainedCteProjectionAndFilterPushDown()
generateMaterializedCTEInformation("cte5", 1, false, true)));
}

/**
 * Joins the CTE {@code temp} against a plain subquery over {@code customer} and
 * verifies the results together with the expected materialized-CTE information
 * for {@code temp}.
 */
@Test
public void testCTEMaterializationWithEnhancedScheduling()
{
    String query = "WITH temp as (SELECT orderkey FROM ORDERS) " +
            "SELECT * FROM temp t1 JOIN (SELECT custkey FROM customer) c ON t1.orderkey=c.custkey";
    verifyResults(
            getQueryRunner(),
            query,
            ImmutableList.of(generateMaterializedCTEInformation("temp", 1, false, true)));
}

@Test
public void testWrittenIntemediateByteLimit()
throws Exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ public final class SystemSessionProperties
public static final String QUERY_RETRY_MAX_EXECUTION_TIME = "query_retry_max_execution_time";
public static final String PARTIAL_RESULTS_ENABLED = "partial_results_enabled";
public static final String PARTIAL_RESULTS_COMPLETION_RATIO_THRESHOLD = "partial_results_completion_ratio_threshold";
// Session property names are snake_case like the neighbouring constants; the hyphenated
// form did not match the documented `enhanced_cte_scheduling_enabled` session property.
public static final String ENHANCED_CTE_SCHEDULING_ENABLED = "enhanced_cte_scheduling_enabled";
public static final String PARTIAL_RESULTS_MAX_EXECUTION_TIME_MULTIPLIER = "partial_results_max_execution_time_multiplier";
public static final String OFFSET_CLAUSE_ENABLED = "offset_clause_enabled";
public static final String VERBOSE_EXCEEDED_MEMORY_LIMIT_ERRORS_ENABLED = "verbose_exceeded_memory_limit_errors_enabled";
Expand Down Expand Up @@ -1282,6 +1283,11 @@ public SystemSessionProperties(
"Minimum query completion ratio threshold for partial results",
featuresConfig.getPartialResultsCompletionRatioThreshold(),
false),
booleanProperty(
ENHANCED_CTE_SCHEDULING_ENABLED,
"Applicable for CTE Materialization. If enabled, only tablescans of the pending tablewriters are blocked and other stages can continue.",
featuresConfig.getEnhancedCTESchedulingEnabled(),
true),
booleanProperty(
OFFSET_CLAUSE_ENABLED,
"Enable support for OFFSET clause",
Expand Down Expand Up @@ -2690,6 +2696,11 @@ public static double getPartialResultsCompletionRatioThreshold(Session session)
return session.getSystemProperty(PARTIAL_RESULTS_COMPLETION_RATIO_THRESHOLD, Double.class);
}

/**
 * Returns whether enhanced CTE scheduling should be used for the given session.
 * This requires both that CTE materialization is applicable and that the
 * {@code ENHANCED_CTE_SCHEDULING_ENABLED} session property is set.
 */
public static boolean isEnhancedCTESchedulingEnabled(Session session)
{
    // Short-circuit: the session property is only consulted when CTE materialization applies.
    return isCteMaterializationApplicable(session) && session.getSystemProperty(ENHANCED_CTE_SCHEDULING_ENABLED, Boolean.class);
}

public static double getPartialResultsMaxExecutionTimeMultiplier(Session session)
{
return session.getSystemProperty(PARTIAL_RESULTS_MAX_EXECUTION_TIME_MULTIPLIER, Double.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@
import com.facebook.presto.metadata.Split;
import com.facebook.presto.server.remotetask.HttpRemoteTask;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.plan.CteMaterializationInfo;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.TableFinishNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.split.RemoteSplit;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -60,8 +65,10 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.getMaxFailedTaskPercentage;
import static com.facebook.presto.SystemSessionProperties.isEnhancedCTESchedulingEnabled;
import static com.facebook.presto.failureDetector.FailureDetector.State.GONE;
import static com.facebook.presto.operator.ExchangeOperator.REMOTE_CONNECTOR_ID;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
Expand Down Expand Up @@ -557,7 +564,6 @@ private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, M
// stage finished while we were scheduling this task
task.abort();
}

return task;
}

Expand Down Expand Up @@ -594,6 +600,59 @@ private static Split createRemoteSplitFor(TaskId taskId, URI remoteSourceTaskLoc
return new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), new RemoteSplit(new Location(splitLocation), remoteSourceTaskId));
}

/**
 * Returns the CTE id attached to the first {@link TableFinishNode} found in the
 * plan tree rooted at {@code source}.
 *
 * @throws IllegalStateException if no CTE id can be obtained (no TableFinishNode
 * is found, or it carries no CTE materialization info)
 */
private static String getCteIdFromSource(PlanNode source)
{
    return PlanNodeSearcher.searchFrom(source)
            .where(TableFinishNode.class::isInstance)
            .findFirst()
            .map(TableFinishNode.class::cast)
            .flatMap(TableFinishNode::getCteMaterializationInfo)
            .map(CteMaterializationInfo::getCteId)
            .orElseThrow(() -> new IllegalStateException("TemporaryTableInfo has no CTE ID"));
}

/**
 * Returns true when this stage's fragment contains a {@link TableFinishNode}
 * that carries CTE materialization info.
 */
public boolean isCTETableFinishStage()
{
    // NOTE(review): findSingle presumably rejects more than one match - confirm against PlanNodeSearcher
    return PlanNodeSearcher.searchFrom(planFragment.getRoot())
            .where(planNode -> {
                if (!(planNode instanceof TableFinishNode)) {
                    return false;
                }
                return ((TableFinishNode) planNode).getCteMaterializationInfo().isPresent();
            })
            .findSingle()
            .isPresent();
}

/**
 * Returns the id of the CTE written by this stage.
 *
 * @throws IllegalStateException if this stage is not a CTE table-finish stage
 */
public String getCTEWriterId()
{
    if (isCTETableFinishStage()) {
        return getCteIdFromSource(planFragment.getRoot());
    }
    throw new IllegalStateException("This stage is not a CTE writer stage");
}

/**
 * Returns true when enhanced CTE scheduling is enabled for the session and at
 * least one {@link TableScanNode} in this stage's fragment carries CTE
 * materialization info.
 */
public boolean requiresMaterializedCTE()
{
    return isEnhancedCTESchedulingEnabled(session) &&
            PlanNodeSearcher.searchFrom(planFragment.getRoot())
                    .where(planNode -> planNode instanceof TableScanNode)
                    .findAll().stream()
                    .map(TableScanNode.class::cast)
                    .anyMatch(tableScan -> tableScan.getCteMaterializationInfo().isPresent());
}

/**
 * Returns the ids of the CTEs read by this stage, one entry per
 * {@link TableScanNode} in the fragment that carries CTE materialization info.
 *
 * Table scans without CTE materialization info (scans of regular tables) are
 * skipped rather than rejected: {@code requiresMaterializedCTE()} is true as
 * soon as any scan in the fragment reads a CTE, so a fragment may mix CTE scans
 * with regular scans.
 *
 * @return the referenced CTE ids; empty when no scan reads a materialized CTE
 */
public List<String> getRequiredCTEList()
{
    return PlanNodeSearcher.searchFrom(planFragment.getRoot())
            .where(planNode -> planNode instanceof TableScanNode)
            .findAll().stream()
            .map(planNode -> ((TableScanNode) planNode).getCteMaterializationInfo())
            .filter(cteInfo -> cteInfo.isPresent())
            .map(cteInfo -> cteInfo.get().getCteId())
            .collect(Collectors.toList());
}

private void updateTaskStatus(TaskId taskId, TaskStatus taskStatus)
{
StageExecutionState stageExecutionState = getState();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution.scheduler;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Tracks the completion status of table-finish nodes that write temporary tables for CTE materialization.
 * CTEMaterializationTracker manages a map from CTE name to the future that completes once that CTE
 * has been marked as materialized.
 * When a stage includes a CTE table finish, it marks the corresponding CTE as materialized and completes
 * the associated future.
 * This signals the scheduler that some dependency has been resolved, prompting it to resume/continue scheduling.
 */
public class CTEMaterializationTracker
{
    private final Map<String, SettableFuture<Void>> materializationFutures = new ConcurrentHashMap<>();

    /**
     * Returns a future that completes once the given CTE has been marked as materialized.
     * The returned future is a non-cancellation-propagating view, so cancelling it does not
     * cancel the future tracked internally.
     *
     * @param cteName name of the CTE to wait for
     * @throws IllegalArgumentException if the tracked future for this CTE was found cancelled
     */
    public ListenableFuture<Void> getFutureForCTE(String cteName)
    {
        return Futures.nonCancellationPropagating(getOrCreateFuture(cteName));
    }

    /**
     * Marks the given CTE as materialized, completing its future. If nobody asked for the
     * future yet, an already-completed one is recorded for later callers.
     *
     * @param cteName name of the CTE that finished materializing
     * @throws IllegalArgumentException if the tracked future for this CTE was found cancelled
     */
    public void markCTEAsMaterialized(String cteName)
    {
        // Complete the future after the map operation has returned: listeners run as part of
        // set(), and must not execute inside a ConcurrentHashMap remapping function.
        getOrCreateFuture(cteName).set(null);
    }

    // Returns the tracked future for the CTE, creating it on first use.
    private SettableFuture<Void> getOrCreateFuture(String cteName)
    {
        SettableFuture<Void> future = materializationFutures.computeIfAbsent(cteName, key -> SettableFuture.create());
        // Lazy template formatting; the CTE name is now actually part of the message.
        Preconditions.checkArgument(!future.isCancelled(),
                "Error: Existing future was found cancelled in CTEMaterializationTracker for cte %s", cteName);
        return future;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public class FixedSourcePartitionedScheduler

private final Queue<Integer> tasksToRecover = new ConcurrentLinkedQueue<>();

private final CTEMaterializationTracker cteMaterializationTracker;

@GuardedBy("this")
private boolean closed;

Expand All @@ -87,13 +89,15 @@ public FixedSourcePartitionedScheduler(
int splitBatchSize,
OptionalInt concurrentLifespansPerTask,
NodeSelector nodeSelector,
List<ConnectorPartitionHandle> partitionHandles)
List<ConnectorPartitionHandle> partitionHandles,
CTEMaterializationTracker cteMaterializationTracker)
{
requireNonNull(stage, "stage is null");
requireNonNull(splitSources, "splitSources is null");
requireNonNull(bucketNodeMap, "bucketNodeMap is null");
checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty");
requireNonNull(partitionHandles, "partitionHandles is null");
this.cteMaterializationTracker = cteMaterializationTracker;

this.stage = stage;
this.nodes = ImmutableList.copyOf(nodes);
Expand Down Expand Up @@ -179,6 +183,29 @@ public ScheduleResult schedule()
{
// schedule a task on every node in the distribution
List<RemoteTask> newTasks = ImmutableList.of();

// CTE Materialization Check
if (stage.requiresMaterializedCTE()) {
List<ListenableFuture<?>> blocked = new ArrayList<>();
List<String> requiredCTEIds = stage.getRequiredCTEList();
for (String cteId : requiredCTEIds) {
ListenableFuture<Void> cteFuture = cteMaterializationTracker.getFutureForCTE(cteId);
if (!cteFuture.isDone()) {
// Add CTE materialization future to the blocked list
blocked.add(cteFuture);
}
}
// If any CTE is not materialized, return a blocked ScheduleResult
if (!blocked.isEmpty()) {
return ScheduleResult.blocked(
false,
newTasks,
whenAnyComplete(blocked),
BlockedReason.WAITING_FOR_CTE_MATERIALIZATION,
0);
}
}
// schedule a task on every node in the distribution
if (!scheduledTasks) {
newTasks = Streams.mapWithIndex(
nodes.stream(),
Expand All @@ -191,9 +218,8 @@ public ScheduleResult schedule()
// notify listeners that we have scheduled all tasks so they can set no more buffers or exchange splits
stage.transitionToFinishedTaskScheduling();
}

boolean allBlocked = true;
List<ListenableFuture<?>> blocked = new ArrayList<>();
boolean allBlocked = true;
BlockedReason blockedReason = BlockedReason.NO_ACTIVE_DRIVER_GROUP;

if (groupedLifespanScheduler.isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,19 @@ public enum BlockedReason
* grouped execution where there are multiple lifespans per task).
*/
MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE,

/**
* Waiting for the completion of CTE materialization by the table writer.
*/
WAITING_FOR_CTE_MATERIALIZATION,
/**/;

public BlockedReason combineWith(BlockedReason other)
{
switch (this) {
case WRITER_SCALING:
throw new IllegalArgumentException("cannot be combined");
case WAITING_FOR_CTE_MATERIALIZATION:
case NO_ACTIVE_DRIVER_GROUP:
return other;
case SPLIT_QUEUES_FULL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ public SectionExecution createSectionExecutions(
boolean summarizeTaskInfo,
RemoteTaskFactory remoteTaskFactory,
SplitSourceFactory splitSourceFactory,
int attemptId)
int attemptId,
CTEMaterializationTracker cteMaterializationTracker)
{
// Only fetch a distribution once per section to ensure all stages see the same machine assignments
Map<PartitioningHandle, NodePartitionMap> partitioningCache = new HashMap<>();
Expand All @@ -186,7 +187,8 @@ public SectionExecution createSectionExecutions(
summarizeTaskInfo,
remoteTaskFactory,
splitSourceFactory,
attemptId);
attemptId,
cteMaterializationTracker);
StageExecutionAndScheduler rootStage = getLast(sectionStages);
rootStage.getStageExecution().setOutputBuffers(outputBuffers);
return new SectionExecution(rootStage, sectionStages);
Expand All @@ -205,7 +207,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions(
boolean summarizeTaskInfo,
RemoteTaskFactory remoteTaskFactory,
SplitSourceFactory splitSourceFactory,
int attemptId)
int attemptId,
CTEMaterializationTracker cteMaterializationTracker)
{
ImmutableList.Builder<StageExecutionAndScheduler> stageExecutionAndSchedulers = ImmutableList.builder();

Expand Down Expand Up @@ -240,7 +243,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions(
summarizeTaskInfo,
remoteTaskFactory,
splitSourceFactory,
attemptId);
attemptId,
cteMaterializationTracker);
stageExecutionAndSchedulers.addAll(subTree);
childStagesBuilder.add(getLast(subTree).getStageExecution());
}
Expand All @@ -262,7 +266,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions(
stageExecution,
partitioningHandle,
tableWriteInfo,
childStageExecutions);
childStageExecutions,
cteMaterializationTracker);
stageExecutionAndSchedulers.add(new StageExecutionAndScheduler(
stageExecution,
stageLinkage,
Expand All @@ -281,7 +286,8 @@ private StageScheduler createStageScheduler(
SqlStageExecution stageExecution,
PartitioningHandle partitioningHandle,
TableWriteInfo tableWriteInfo,
Set<SqlStageExecution> childStageExecutions)
Set<SqlStageExecution> childStageExecutions,
CTEMaterializationTracker cteMaterializationTracker)
{
Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(plan.getFragment(), session, tableWriteInfo);
int maxTasksPerStage = getMaxTasksPerStage(session);
Expand Down Expand Up @@ -341,7 +347,8 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) {
splitBatchSize,
getConcurrentLifespansPerNode(session),
nodeSelector,
ImmutableList.of(NOT_PARTITIONED));
ImmutableList.of(NOT_PARTITIONED),
cteMaterializationTracker);
}
else if (!splitSources.isEmpty()) {
// contains local source
Expand Down Expand Up @@ -400,7 +407,8 @@ else if (!splitSources.isEmpty()) {
splitBatchSize,
getConcurrentLifespansPerNode(session),
nodeScheduler.createNodeSelector(session, connectorId, nodePredicate),
connectorPartitionHandles);
connectorPartitionHandles,
cteMaterializationTracker);
if (plan.getFragment().getStageExecutionDescriptor().isRecoverableGroupedExecution()) {
stageExecution.registerStageTaskRecoveryCallback(taskId -> {
checkArgument(taskId.getStageExecutionId().getStageId().equals(stageId), "The task did not execute this stage");
Expand Down
Loading

0 comments on commit 4e1a390

Please sign in to comment.