track total bytes written by ShuffleMapTasks

This commit is contained in:
Imran Rashid 2013-02-05 10:15:28 -08:00
parent b430d2359d
commit 843084d69d
3 changed files with 17 additions and 2 deletions

View file

@ -511,8 +511,11 @@ class DAGScheduler(
}
case smt: ShuffleMapTask =>
val stage = idToStage(smt.stageId)
val status = event.result.asInstanceOf[MapStatus]
smt.totalBytesWritten match {
case Some(b) => stageToInfos(stage).shuffleBytesWritten += b
case None => throw new RuntimeException("shuffle stask completed without tracking bytes written")
}
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (failedGeneration.contains(execId) && smt.generation <= failedGeneration(execId)) {

View file

@ -81,6 +81,9 @@ private[spark] class ShuffleMapTask(
with Externalizable
with Logging {
var totalBytesWritten : Option[Long] = None
protected def this() = this(0, null, null, 0, null)
var split = if (rdd == null) {
@ -130,14 +133,18 @@ private[spark] class ShuffleMapTask(
val compressedSizes = new Array[Byte](numOutputSplits)
var totalBytes = 0l
val blockManager = SparkEnv.get.blockManager
for (i <- 0 until numOutputSplits) {
val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
// Get a Scala iterator from Java map
val iter: Iterator[(Any, Any)] = buckets(i).iterator
val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
totalBytes += size
compressedSizes(i) = MapOutputTracker.compressSize(size)
}
totalBytesWritten = Some(totalBytes)
return new MapStatus(blockManager.blockManagerId, compressedSizes)
} finally {

View file

@ -4,7 +4,12 @@ import cluster.TaskInfo
import collection._
import spark.util.Distribution
case class StageInfo(val stage: Stage, val taskInfos: mutable.Buffer[TaskInfo] = mutable.Buffer[TaskInfo]()) {
case class StageInfo(
val stage: Stage,
val taskInfos: mutable.Buffer[TaskInfo] = mutable.Buffer[TaskInfo](),
val shuffleBytesWritten : mutable.Buffer[Long] = mutable.Buffer[Long](),
val shuffleBytesRead : mutable.Buffer[Long] = mutable.Buffer[Long]()
) {
def name = stage.rdd.name + "(" + stage.origin + ")"