track total bytes written by ShuffleMapTasks
This commit is contained in:
parent
b430d2359d
commit
843084d69d
|
@ -511,8 +511,11 @@ class DAGScheduler(
|
|||
}
|
||||
|
||||
case smt: ShuffleMapTask =>
|
||||
val stage = idToStage(smt.stageId)
|
||||
val status = event.result.asInstanceOf[MapStatus]
|
||||
smt.totalBytesWritten match {
|
||||
case Some(b) => stageToInfos(stage).shuffleBytesWritten += b
|
||||
// A finished ShuffleMapTask should always have recorded its byte count; treat a missing value as an error.
case None => throw new RuntimeException("shuffle task completed without tracking bytes written")
|
||||
}
|
||||
val execId = status.location.executorId
|
||||
logDebug("ShuffleMapTask finished on " + execId)
|
||||
if (failedGeneration.contains(execId) && smt.generation <= failedGeneration(execId)) {
|
||||
|
|
|
@ -81,6 +81,9 @@ private[spark] class ShuffleMapTask(
|
|||
with Externalizable
|
||||
with Logging {
|
||||
|
||||
|
||||
/** Total size of the shuffle output written by this task; remains None until all output blocks have been written. */
var totalBytesWritten: Option[Long] = None
|
||||
|
||||
protected def this() = this(0, null, null, 0, null)
|
||||
|
||||
var split = if (rdd == null) {
|
||||
|
@ -130,14 +133,18 @@ private[spark] class ShuffleMapTask(
|
|||
|
||||
val compressedSizes = new Array[Byte](numOutputSplits)
|
||||
|
||||
// Running sum of the sizes reported by blockManager.put for each output split.
var totalBytes = 0L
|
||||
|
||||
val blockManager = SparkEnv.get.blockManager
|
||||
for (i <- 0 until numOutputSplits) {
|
||||
val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
|
||||
// Get a Scala iterator from Java map
|
||||
val iter: Iterator[(Any, Any)] = buckets(i).iterator
|
||||
val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
|
||||
totalBytes += size
|
||||
compressedSizes(i) = MapOutputTracker.compressSize(size)
|
||||
}
|
||||
totalBytesWritten = Some(totalBytes)
|
||||
|
||||
return new MapStatus(blockManager.blockManagerId, compressedSizes)
|
||||
} finally {
|
||||
|
|
|
@ -4,7 +4,12 @@ import cluster.TaskInfo
|
|||
import collection._
|
||||
import spark.util.Distribution
|
||||
|
||||
case class StageInfo(val stage: Stage, val taskInfos: mutable.Buffer[TaskInfo] = mutable.Buffer[TaskInfo]()) {
|
||||
case class StageInfo(
|
||||
val stage: Stage,
|
||||
val taskInfos: mutable.Buffer[TaskInfo] = mutable.Buffer[TaskInfo](),
|
||||
val shuffleBytesWritten : mutable.Buffer[Long] = mutable.Buffer[Long](),
|
||||
val shuffleBytesRead : mutable.Buffer[Long] = mutable.Buffer[Long]()
|
||||
) {
|
||||
|
||||
def name = stage.rdd.name + "(" + stage.origin + ")"
|
||||
|
||||
|
|
Loading…
Reference in a new issue