Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-50648][CORE] Cleanup zombie tasks in non-running stages when the job is cancelled #49270

Closed
wants to merge 15 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -2187,7 +2187,8 @@ private[spark] class DAGScheduler(
log"we will roll back and rerun below stages which include itself and all its " +
log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
}

failedStage.markResubmitInFetchFailed()
mapStage.markResubmitInFetchFailed()
yabola marked this conversation as resolved.
Show resolved Hide resolved
// We expect one executor failure to trigger many FetchFailures in rapid succession,
// but all of those task failures can typically be handled by a single resubmission of
// the failed stage. We avoid flooding the scheduler's event queue with resubmit
Expand Down Expand Up @@ -2937,7 +2938,9 @@ private[spark] class DAGScheduler(
} else {
// This stage is only used by the job, so finish the stage if it is running.
val stage = stageIdToStage(stageId)
if (runningStages.contains(stage)) {
val shouldKill = runningStages.contains(stage) ||
(waitingStages.contains(stage) && stage.resubmitInFetchFailed)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we check the failedAttemptIds instead?

Suggested change
(waitingStages.contains(stage) && stage.resubmitInFetchFailed)
stage.failedAttemptIds.nonEmpty

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we can Please see this
I'm not sure which way is better.

Copy link
Contributor

@mridulm mridulm Dec 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like @Ngone51's suggestion better - simply check for stage.failedAttemptIds.nonEmpty || runningStages.contains(stage).
I can see an argument being made for failed as well.
With this, the PR will boil down to this change and tests to stress this logic ofcourse.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mridulm @Ngone51 do you think it is necessary (waitingStages.contains(stage) && stage.failedAttemptIds.nonEmpty) || runningStages.contains(stage). Only considering failedAttemptIds may result in repeated calls to the the stage already completed and failed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only considering failedAttemptIds may result in repeated calls to the the stage already completed and failed.

It looks like there could be a case where the stage exists in failedStages but not in waitingStages, e.g., in the case of fetch failures, map stage and reduce stage can be added into failedStages, but the related job could be canceled before they were resubmitted. So adding waitingStages.contains(stage) would miss the stages in failedStages. And I don't think we would have repeated calls as we don't kill tasks for those failed stages.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the confirmation, done

if (shouldKill) {
try { // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask
taskScheduler.killAllTaskAttempts(stageId, shouldInterruptTaskThread(job), reason)
if (legacyAbortStageAfterKillTasks) {
Expand Down
7 changes: 7 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Stage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ private[scheduler] abstract class Stage(
/** The ID to use for the next new attempt for this stage. */
private var nextAttemptId: Int = 0
private[scheduler] def getNextAttemptId: Int = nextAttemptId
private[scheduler] var _resubmitInFetchFailed: Boolean = false

val name: String = callSite.shortForm
val details: String = callSite.longForm
Expand All @@ -96,6 +97,12 @@ private[scheduler] abstract class Stage(
failedAttemptIds.clear()
}

private[scheduler] def resubmitInFetchFailed: Boolean = _resubmitInFetchFailed

private[scheduler] def markResubmitInFetchFailed() : Unit = {
_resubmitInFetchFailed = true
}

/** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
def makeNewStageAttempt(
numPartitionsToCompute: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,6 @@ private[spark] trait TaskScheduler {
*/
def applicationAttemptId(): Option[String]


def hasRunningTasks(stageId: Int): Boolean
yabola marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,16 @@ private[spark] class TaskSchedulerImpl(

override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()

override def hasRunningTasks(stageId: Int): Boolean = synchronized {
var hasRunningTasks = false
taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
attempts.foreach { case (_, tsm) =>
hasRunningTasks = hasRunningTasks || tsm.runningTasksSet.nonEmpty
}
}
hasRunningTasks
yabola marked this conversation as resolved.
Show resolved Hide resolved
}

// exposed for testing
private[scheduler] def taskSetManagerForAttempt(
stageId: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
override def applicationAttemptId(): Option[String] = None
override def hasRunningTasks(stageId: Int): Boolean = false
override def executorDecommission(
executorId: String,
decommissionInfo: ExecutorDecommissionInfo): Unit = {
Expand Down Expand Up @@ -941,6 +942,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
override def applicationAttemptId(): Option[String] = None
override def hasRunningTasks(stageId: Int): Boolean = false
override def executorDecommission(
executorId: String,
decommissionInfo: ExecutorDecommissionInfo): Unit = {}
Expand Down Expand Up @@ -2248,7 +2250,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
// original result task 1.0 succeed
runEvent(makeCompletionEvent(taskSets(1).tasks(1), Success, 42))
sc.listenerBus.waitUntilEmpty()
assert(completedStage === List(0, 1, 1, 0))
assert(completedStage === List(0, 1, 1, 0, 1))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last one added is the resubmit stage (FetchFailed) and in waiting stages. We will kill it and one more SparkListenerStageCompleted event will be added ( see markStageAsFinished)

Copy link
Contributor Author

@yabola yabola Dec 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After changing to judging by failedAttemptIds, it won't cancel. Because all tasks finished already , in markStageAsFinished will remove failedAttemptIds if no error message.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yabola I'm curious about the difference here. With the current approach, doesn't the stage still has to be killed because of failedAttemptIds.nonEmpty?

Copy link
Contributor Author

@yabola yabola Dec 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me explain the timeline of the last event in this UT:

  1. map stage is running, result stage is waiting (result task 1.0 is running)
  2. UT result task 1.0 success(no running tasks in result stage any more)
  3. in handleTaskCompletion , the result stage markStageAsFinished and clean result stage's failedAttemptIds
  4. cancelRunningIndependentStages cancel map stage (it is in running stage) . Result stage is waiting , but don't have failedAttemptIds , so it won't be killed in cancelRunningIndependentStages (and also no running tasks in result stage)

In this UT, it is really no need to kill the last result stage.

In addition, the result stage will always definitely kill all tasks when success, we don't have to worry about this. please see here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, sounds good to me.

assert(scheduler.activeJobs.isEmpty)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ private class DummyTaskScheduler extends TaskScheduler {
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
override def applicationAttemptId(): Option[String] = None
override def hasRunningTasks(stageId: Int): Boolean = false
def executorHeartbeatReceived(
execId: String,
accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -938,8 +938,12 @@ class AdaptiveQueryExecSuite
val error = intercept[SparkException] {
joined.collect()
}
assert((Seq(error) ++ Option(error.getCause) ++ error.getSuppressed()).exists(
e => e.getMessage() != null && e.getMessage().contains("coalesce test error")))
val errorMessages = (Seq(error) ++ Option(error.getCause) ++ error.getSuppressed())
.filter(e => e.getMessage != null).map(e => e.getMessage)
assert(errorMessages.exists(
e => e.contains("coalesce test error")),
s"Error messages should contain `coalesce test error`, " +
s"error messages: $errorMessages")
yabola marked this conversation as resolved.
Show resolved Hide resolved

val adaptivePlan = joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec]

Expand Down
Loading