Fix for SPARK-870.

This patch fixes a bug where the Spark UI did not display the correct total number of
tasks when the number of tasks in a stage differs from the number of RDD partitions.

It also cleans up the listener API a bit by embedding the task count in the
StageInfo class rather than passing it separately.
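For context, a stage's task count can legitimately differ from its RDD's partition
count, because some actions evaluate only a subset of partitions. A minimal sketch of
the scenario (hypothetical shell session; sc is the usual SparkContext):

    val rdd = sc.parallelize(1 to 100, 8)  // an RDD with 8 partitions
    rdd.first()  // evaluates a single partition first: the stage runs 1 task, not 8
    rdd.count()  // evaluates every partition: the stage runs 8 tasks

Before this patch the UI derived "total tasks" from the partition count, so it
reported 8 in both cases.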
Patrick Wendell, 2013-10-20 15:43:42 -07:00
commit 2fa3c4c49c, parent a854f5bfcf
7 changed files with 17 additions and 13 deletions

DAGScheduler.scala

@@ -183,7 +183,7 @@ class DAGScheduler(
     shuffleToMapStage.get(shuffleDep.shuffleId) match {
       case Some(stage) => stage
       case None =>
-        val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
+        val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
         shuffleToMapStage(shuffleDep.shuffleId) = stage
         stage
     }
@@ -196,6 +196,7 @@ class DAGScheduler(
    */
   private def newStage(
       rdd: RDD[_],
+      numTasks: Int,
       shuffleDep: Option[ShuffleDependency[_,_]],
       jobId: Int,
       callSite: Option[String] = None)
@@ -208,7 +209,8 @@ class DAGScheduler(
       mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
     }
     val id = nextStageId.getAndIncrement()
-    val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+    val stage =
+      new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
     stageIdToStage(id) = stage
     stageToInfos(stage) = StageInfo(stage)
     stage
@@ -362,7 +364,7 @@ class DAGScheduler(
   private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
     event match {
       case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
-        val finalStage = newStage(rdd, None, jobId, Some(callSite))
+        val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
         val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
         clearCacheLocs()
         logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -585,7 +587,7 @@ class DAGScheduler(
       // must be run listener before possible NotSerializableException
       // should be "StageSubmitted" first and then "JobEnded"
-      listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), tasks.size, properties))
+      listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
       if (tasks.size > 0) {
         // Preemptively serialize a task to make sure it can be serialized. We are catching this
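Note how the two newStage() call sites now differ: a shuffle map stage computes every
partition of its RDD, so it passes shuffleDep.rdd.partitions.size, while the final
result stage computes only the partitions the job requested and passes partitions.size.
That distinction is exactly where the task count and the partition count can diverge.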

JobLogger.scala

@@ -207,8 +207,8 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
   }

   override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
-    stageLogInfo(stageSubmitted.stage.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
-      stageSubmitted.stage.stageId, stageSubmitted.taskSize))
+    stageLogInfo(stageSubmitted.stage.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+      stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
   }

   override def onStageCompleted(stageCompleted: StageCompleted) {

SparkListener.scala

@@ -24,7 +24,7 @@ import org.apache.spark.executor.TaskMetrics
 sealed trait SparkListenerEvents

-case class SparkListenerStageSubmitted(stage: StageInfo, taskSize: Int, properties: Properties)
+case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
   extends SparkListenerEvents

 case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
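With the taskSize parameter gone, listener implementations read the task count from
StageInfo itself. A minimal sketch of a custom listener against the new event shape
(the class name and println body are illustrative, not part of this patch):

    import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}

    class TaskCountListener extends SparkListener {
      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
        // The task count now lives on StageInfo instead of on the event itself.
        val info = stageSubmitted.stage
        println("Stage %s submitted: %d tasks over %d partitions".format(
          info.name, info.numTasks, info.numPartitions))
      }
    }

Such a listener would be registered via SparkContext.addSparkListener(...).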

Stage.scala

@@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
 private[spark] class Stage(
     val id: Int,
     val rdd: RDD[_],
+    val numTasks: Int,
     val shuffleDep: Option[ShuffleDependency[_,_]],  // Output shuffle if stage is a map stage
     val parents: List[Stage],
     val jobId: Int,

StageInfo.scala

@@ -30,9 +30,8 @@ case class StageInfo(
   var completionTime: Option[Long] = None
   val rddName = stage.rdd.toString
   val name = stage.name
-  // TODO: We should also track the number of tasks associated with this stage, which may not
-  //       be equal to numPartitions.
   val numPartitions = stage.numPartitions
+  val numTasks = stage.numTasks
   override def toString = rddName
 }
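This resolves the TODO deleted above: StageInfo now tracks the task count explicitly
instead of leaving consumers to assume it equals numPartitions.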

StageTable.scala

@@ -94,7 +94,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgressUI)
       case f if f > 0 => "(%s failed)".format(f)
       case _ => ""
     }
-    val totalTasks = s.numPartitions
+    val totalTasks = s.numTasks
     val poolName = listener.stageToPool.get(s.stageId)
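This one-line change is the user-visible fix: the stage table's total-task column now
reports the stage's actual task count rather than its partition count.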

JobLoggerSuite.scala

@@ -58,11 +58,13 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
     val parentRdd = makeRdd(4, Nil)
     val shuffleDep = new ShuffleDependency(parentRdd, null)
     val rootRdd = makeRdd(4, List(shuffleDep))
-    val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
-    val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
+    val shuffleMapStage =
+      new Stage(1, parentRdd, parentRdd.partitions.size, Some(shuffleDep), Nil, jobID, None)
+    val rootStage =
+      new Stage(0, rootRdd, rootRdd.partitions.size, None, List(shuffleMapStage), jobID, None)
     val rootStageInfo = new StageInfo(rootStage)
-    joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, 4, null))
+    joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
     joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
     parentRdd.setName("MyRDD")
     joblogger.getRddNameTest(parentRdd) should be ("MyRDD")