Fix for Spark-870.

This patch fixes a bug where the Spark UI didn't display the correct number of total
tasks if the number of tasks in a Stage doesn't equal the number of RDD partitions.

It also cleans up the listener API a bit by embedding this information in the
StageInfo class rather than passing it seperately.
This commit is contained in:
Patrick Wendell 2013-10-20 15:43:42 -07:00
parent a854f5bfcf
commit 2fa3c4c49c
7 changed files with 17 additions and 13 deletions

View file

@ -183,7 +183,7 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
@ -196,6 +196,7 @@ class DAGScheduler(
*/
private def newStage(
rdd: RDD[_],
numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
jobId: Int,
callSite: Option[String] = None)
@ -208,7 +209,8 @@ class DAGScheduler(
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
val stage =
new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
stageToInfos(stage) = StageInfo(stage)
stage
@ -362,7 +364,7 @@ class DAGScheduler(
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
val finalStage = newStage(rdd, None, jobId, Some(callSite))
val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@ -585,7 +587,7 @@ class DAGScheduler(
// must be run listener before possible NotSerializableException
// should be "StageSubmitted" first and then "JobEnded"
listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), tasks.size, properties))
listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this

View file

@ -207,8 +207,8 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
}
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
stageLogInfo(stageSubmitted.stage.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
stageSubmitted.stage.stageId, stageSubmitted.taskSize))
stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
}
override def onStageCompleted(stageCompleted: StageCompleted) {

View file

@ -24,7 +24,7 @@ import org.apache.spark.executor.TaskMetrics
sealed trait SparkListenerEvents
case class SparkListenerStageSubmitted(stage: StageInfo, taskSize: Int, properties: Properties)
case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
extends SparkListenerEvents
case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents

View file

@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
val numTasks: Int,
val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,

View file

@ -30,9 +30,8 @@ case class StageInfo(
var completionTime: Option[Long] = None
val rddName = stage.rdd.toString
val name = stage.name
// TODO: We should also track the number of tasks associated with this stage, which may not
// be equal to numPartitions.
val numPartitions = stage.numPartitions
val numTasks = stage.numTasks
override def toString = rddName
}

View file

@ -94,7 +94,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
case f if f > 0 => "(%s failed)".format(f)
case _ => ""
}
val totalTasks = s.numPartitions
val totalTasks = s.numTasks
val poolName = listener.stageToPool.get(s.stageId)

View file

@ -58,11 +58,13 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
val parentRdd = makeRdd(4, Nil)
val shuffleDep = new ShuffleDependency(parentRdd, null)
val rootRdd = makeRdd(4, List(shuffleDep))
val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
val shuffleMapStage =
new Stage(1, parentRdd, parentRdd.partitions.size, Some(shuffleDep), Nil, jobID, None)
val rootStage =
new Stage(0, rootRdd, rootRdd.partitions.size, None, List(shuffleMapStage), jobID, None)
val rootStageInfo = new StageInfo(rootStage)
joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, 4, null))
joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
parentRdd.setName("MyRDD")
joblogger.getRddNameTest(parentRdd) should be ("MyRDD")