diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 0dec93ff83..b412ee2e2b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -183,7 +183,7 @@ class DAGScheduler(
     shuffleToMapStage.get(shuffleDep.shuffleId) match {
       case Some(stage) => stage
       case None =>
-        val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
+        val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
         shuffleToMapStage(shuffleDep.shuffleId) = stage
         stage
     }
@@ -196,6 +196,7 @@ class DAGScheduler(
    */
   private def newStage(
       rdd: RDD[_],
+      numTasks: Int,
       shuffleDep: Option[ShuffleDependency[_,_]],
       jobId: Int,
       callSite: Option[String] = None)
@@ -208,7 +209,8 @@ class DAGScheduler(
       mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
     }
     val id = nextStageId.getAndIncrement()
-    val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+    val stage =
+      new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
     stageIdToStage(id) = stage
     stageToInfos(stage) = StageInfo(stage)
     stage
@@ -362,7 +364,7 @@ class DAGScheduler(
   private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
     event match {
       case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
-        val finalStage = newStage(rdd, None, jobId, Some(callSite))
+        val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
         val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
         clearCacheLocs()
         logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -585,7 +587,7 @@ class DAGScheduler(
 
     // must be run listener before possible NotSerializableException
     // should be "StageSubmitted" first and then "JobEnded"
-    listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), tasks.size, properties))
+    listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
 
     if (tasks.size > 0) {
       // Preemptively serialize a task to make sure it can be serialized. We are catching this
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index fe1990bcc1..e16436a9c7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -207,8 +207,8 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
   }
 
   override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
-    stageLogInfo(stageSubmitted.stage.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
-      stageSubmitted.stage.stageId, stageSubmitted.taskSize))
+    stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+      stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
   }
 
   override def onStageCompleted(stageCompleted: StageCompleted) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 49faa0ba7a..1391993147 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -24,7 +24,7 @@ import org.apache.spark.executor.TaskMetrics
 
 sealed trait SparkListenerEvents
 
-case class SparkListenerStageSubmitted(stage: StageInfo, taskSize: Int, properties: Properties)
+case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
   extends SparkListenerEvents
 
 case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index aa293dc6b3..b320d8aa65 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
 private[spark] class Stage(
     val id: Int,
     val rdd: RDD[_],
+    val numTasks: Int,
     val shuffleDep: Option[ShuffleDependency[_,_]],  // Output shuffle if stage is a map stage
     val parents: List[Stage],
     val jobId: Int,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 3156071535..52002f48e3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -30,9 +30,8 @@ case class StageInfo(
   var completionTime: Option[Long] = None
   val rddName = stage.rdd.toString
   val name = stage.name
-  // TODO: We should also track the number of tasks associated with this stage, which may not
-  // be equal to numPartitions.
   val numPartitions = stage.numPartitions
+  val numTasks = stage.numTasks
 
   override def toString = rddName
 }
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index c47878dcd6..b08cbe8d22 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -94,7 +94,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
       case f if f > 0 => "(%s failed)".format(f)
       case _ => ""
     }
-    val totalTasks = s.numPartitions
+    val totalTasks = s.numTasks
 
     val poolName = listener.stageToPool.get(s.stageId)
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index 9c3ca0bb92..8406093246 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -58,11 +58,13 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
     val parentRdd = makeRdd(4, Nil)
     val shuffleDep = new ShuffleDependency(parentRdd, null)
     val rootRdd = makeRdd(4, List(shuffleDep))
-    val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
-    val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
+    val shuffleMapStage =
+      new Stage(1, parentRdd, parentRdd.partitions.size, Some(shuffleDep), Nil, jobID, None)
+    val rootStage =
+      new Stage(0, rootRdd, rootRdd.partitions.size, None, List(shuffleMapStage), jobID, None)
     val rootStageInfo = new StageInfo(rootStage)
 
-    joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, 4, null))
+    joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
     joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
     parentRdd.setName("MyRDD")
     joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
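
Note: since this patch drops the taskSize parameter from SparkListenerStageSubmitted, any external SparkListener that reported per-stage task counts must now read them from the StageInfo carried by the event, as JobLogger does above. A minimal sketch of such a consumer against the post-patch API follows; the listener class name is hypothetical and the log format simply mirrors JobLogger's.

    import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}

    // Hypothetical listener sketch: the task count is no longer a field on the
    // event itself, so it is read from the event's StageInfo (stage.numTasks),
    // which StageInfo now copies from Stage.numTasks.
    class StageSizeLoggingListener extends SparkListener {
      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
        val info = stageSubmitted.stage
        println("STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(info.stageId, info.numTasks))
      }
    }

Registered with sc.addSparkListener(new StageSizeLoggingListener), this emits the same TASK_SIZE line that JobLogger now derives from stage.numTasks, which for a result stage equals partitions.size of the submitted job rather than the number of partitions of the final RDD.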