diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4f7338f74e298..e06b7d86e1db0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -2937,7 +2937,8 @@ private[spark] class DAGScheduler(
         } else {
           // This stage is only used by the job, so finish the stage if it is running.
           val stage = stageIdToStage(stageId)
-          if (runningStages.contains(stage)) {
+          // Stages with failedAttemptIds may have tasks that are running
+          if (runningStages.contains(stage) || stage.failedAttemptIds.nonEmpty) {
             try { // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask
               taskScheduler.killAllTaskAttempts(stageId, shouldInterruptTaskThread(job), reason)
               if (legacyAbortStageAfterKillTasks) {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 243d33fe55a79..0260f05761e57 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -185,6 +185,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   private var firstInit: Boolean = _
   /** Set of TaskSets the DAGScheduler has requested executed. */
   val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+  /** Track running tasks: the key is the stageId, the value is the set of running partitionIds. */
+  var runningTaskInfos = new HashMap[Int, HashSet[Int]]()
   /** Stages for which the DAGScheduler has called TaskScheduler.killAllTaskAttempts(). */
   val cancelledStages = new HashSet[Int]()
 
@@ -206,12 +208,16 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       // normally done by TaskSetManager
       taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
       taskSets += taskSet
+      val taskPartitionIds = new HashSet[Int]()
+      taskSet.tasks.foreach(task => taskPartitionIds += task.partitionId)
+      runningTaskInfos.put(taskSet.stageId, taskPartitionIds)
     }
     override def killTaskAttempt(
       taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
     override def killAllTaskAttempts(
       stageId: Int, interruptThread: Boolean, reason: String): Unit = {
       cancelledStages += stageId
+      runningTaskInfos.remove(stageId)
     }
     override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
       taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
@@ -393,6 +399,16 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
         handleShuffleMergeFinalized(shuffleMapStage, shuffleMapStage.shuffleDep.shuffleMergeId)
       }
     }
+
+    override private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = {
+      super.handleTaskCompletion(event)
+      if (runningTaskInfos.contains(event.task.stageId)) {
+        runningTaskInfos(event.task.stageId) -= event.task.partitionId
+        if (runningTaskInfos(event.task.stageId).isEmpty) {
+          runningTaskInfos.remove(event.task.stageId)
+        }
+      }
+    }
   }
 
   override def beforeEach(): Unit = {
@@ -2248,10 +2264,50 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     // original result task 1.0 succeed
     runEvent(makeCompletionEvent(taskSets(1).tasks(1), Success, 42))
     sc.listenerBus.waitUntilEmpty()
-    assert(completedStage === List(0, 1, 1, 0))
+    assert(completedStage === List(0, 1, 1, 0, 1))
     assert(scheduler.activeJobs.isEmpty)
   }
 
+  test("SPARK-50648: when job is cancelled during shuffle retry in parent stage, " +
+    "should kill all running tasks") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+    completeShuffleMapStageSuccessfully(0, 0, 2)
+    sc.listenerBus.waitUntilEmpty()
+
+    val info = new TaskInfo(
+      3, index = 1, attemptNumber = 1,
+      partitionId = taskSets(1).tasks(0).partitionId, 0L, "", "", TaskLocality.ANY, true)
+    // result task 0.0 fetch failed, but result task 1.0 is still running
+    runEvent(makeCompletionEvent(taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0L, 0, 1, "ignored"),
+      null,
+      Seq.empty,
+      Array.empty,
+      info))
+    sc.listenerBus.waitUntilEmpty()
+
+    Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    // map stage is running after being resubmitted, result stage is waiting
+    // map tasks and the original result task 1.0 are still running
+    assert(scheduler.runningStages.size == 1, "Map stage should be running")
+    val mapStage = scheduler.runningStages.head
+    assert(mapStage.id === 0)
+    assert(mapStage.latestInfo.failureReason.isEmpty)
+    assert(scheduler.waitingStages.size == 1, "Result stage should be waiting")
+    assert(runningTaskInfos.size == 2)
+    assert(runningTaskInfos(taskSets(1).stageId).size == 1,
+      "original result task 1.0 should be running")
+
+    scheduler.cancelAllJobs()
+    // all tasks should be killed
+    assert(runningTaskInfos.isEmpty)
+    assert(scheduler.runningStages.isEmpty)
+    assert(scheduler.waitingStages.isEmpty)
+  }
+
   test("misbehaved accumulator should not crash DAGScheduler and SparkContext") {
     val acc = new LongAccumulator {
       override def add(v: java.lang.Long): Unit = throw new DAGSchedulerSuiteDummyException