Fix for SPARK-870.

This patch fixes a bug where the Spark UI did not display the correct number of total tasks when the number of tasks in a stage does not equal the number of RDD partitions. It also cleans up the listener API a bit by embedding this information in the StageInfo class rather than passing it separately.
parent a854f5bfcf
commit 2fa3c4c49c
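Under the old API, the task count rode along on the event itself (SparkListenerStageSubmitted(stage, taskSize, properties)); under the new one it is read from the embedded StageInfo as stage.numTasks. A minimal sketch of a custom listener written against the revised API follows; the TaskCountListener class and its output are illustrative, not part of this patch.

import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}

// Illustrative listener (not part of this patch). With this change the task
// count comes from the embedded StageInfo rather than from a separate
// taskSize field on the event.
class TaskCountListener extends SparkListener {
  override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
    // Old API: stageSubmitted.taskSize
    // New API: stageSubmitted.stage.numTasks
    println("Stage %d submitted with %d tasks".format(
      stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
  }
}

Such a listener would be registered with SparkContext.addSparkListener, the same way JobLogger is attached.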
@@ -183,7 +183,7 @@ class DAGScheduler(
     shuffleToMapStage.get(shuffleDep.shuffleId) match {
       case Some(stage) => stage
       case None =>
-        val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
+        val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
         shuffleToMapStage(shuffleDep.shuffleId) = stage
         stage
     }
@@ -196,6 +196,7 @@ class DAGScheduler(
    */
   private def newStage(
       rdd: RDD[_],
+      numTasks: Int,
       shuffleDep: Option[ShuffleDependency[_,_]],
       jobId: Int,
       callSite: Option[String] = None)
@@ -208,7 +209,8 @@ class DAGScheduler(
       mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
     }
     val id = nextStageId.getAndIncrement()
-    val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+    val stage =
+      new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
     stageIdToStage(id) = stage
     stageToInfos(stage) = StageInfo(stage)
     stage
@@ -362,7 +364,7 @@ class DAGScheduler(
   private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
     event match {
       case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
-        val finalStage = newStage(rdd, None, jobId, Some(callSite))
+        val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
         val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
         clearCacheLocs()
         logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
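Why partitions.size here rather than rdd.partitions.size: a job may be submitted over only a subset of an RDD's partitions, in which case the final stage runs fewer tasks than the RDD has partitions. That mismatch is the case the UI previously got wrong. A driver-side sketch, assuming an existing SparkContext sc:

// take(1) first submits a job over a single partition, so the final stage
// runs 1 task even though the RDD has 10 partitions. Before this patch the
// UI reported the stage's total tasks as 10 (numPartitions); now it reports 1.
val rdd = sc.parallelize(1 to 100, 10)
rdd.take(1)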
@@ -585,7 +587,7 @@ class DAGScheduler(
     // must be run listener before possible NotSerializableException
     // should be "StageSubmitted" first and then "JobEnded"
-    listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), tasks.size, properties))
+    listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))

     if (tasks.size > 0) {
       // Preemptively serialize a task to make sure it can be serialized. We are catching this
@@ -207,8 +207,8 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
   }

   override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
-    stageLogInfo(stageSubmitted.stage.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
-      stageSubmitted.stage.stageId, stageSubmitted.taskSize))
+    stageLogInfo(stageSubmitted.stage.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+      stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
   }

   override def onStageCompleted(stageCompleted: StageCompleted) {
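The log format is unchanged; with numTasks now sourced from the StageInfo, a submitted-stage record for a 4-task stage would still look like the following (values illustrative):

STAGE_ID=0 STATUS=SUBMITTED TASK_SIZE=4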
@@ -24,7 +24,7 @@ import org.apache.spark.executor.TaskMetrics

 sealed trait SparkListenerEvents

-case class SparkListenerStageSubmitted(stage: StageInfo, taskSize: Int, properties: Properties)
+case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
      extends SparkListenerEvents

 case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
@@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
 private[spark] class Stage(
     val id: Int,
     val rdd: RDD[_],
+    val numTasks: Int,
     val shuffleDep: Option[ShuffleDependency[_,_]],  // Output shuffle if stage is a map stage
     val parents: List[Stage],
     val jobId: Int,
@@ -30,9 +30,8 @@ case class StageInfo(
   var completionTime: Option[Long] = None
   val rddName = stage.rdd.toString
   val name = stage.name
-  // TODO: We should also track the number of tasks associated with this stage, which may not
-  // be equal to numPartitions.
-  val numPartitions = stage.numPartitions
+  val numTasks = stage.numTasks

   override def toString = rddName
 }
@@ -94,7 +94,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
       case f if f > 0 => "(%s failed)".format(f)
       case _ => ""
     }
-    val totalTasks = s.numPartitions
+    val totalTasks = s.numTasks

     val poolName = listener.stageToPool.get(s.stageId)
@@ -58,11 +58,13 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
     val parentRdd = makeRdd(4, Nil)
     val shuffleDep = new ShuffleDependency(parentRdd, null)
     val rootRdd = makeRdd(4, List(shuffleDep))
-    val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
-    val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
+    val shuffleMapStage =
+      new Stage(1, parentRdd, parentRdd.partitions.size, Some(shuffleDep), Nil, jobID, None)
+    val rootStage =
+      new Stage(0, rootRdd, rootRdd.partitions.size, None, List(shuffleMapStage), jobID, None)
     val rootStageInfo = new StageInfo(rootStage)

-    joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, 4, null))
+    joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
     joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
     parentRdd.setName("MyRDD")
     joblogger.getRddNameTest(parentRdd) should be ("MyRDD")