enable task metrics in local mode, add tests
This commit is contained in:
parent
ec30188a2a
commit
20f01a0a1b
|
@ -67,8 +67,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
|
||||||
logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
|
logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
|
||||||
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
|
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
|
||||||
updateDependencies(taskFiles, taskJars) // Download any files added with addFile
|
updateDependencies(taskFiles, taskJars) // Download any files added with addFile
|
||||||
|
val deserStart = System.currentTimeMillis()
|
||||||
val deserializedTask = ser.deserialize[Task[_]](
|
val deserializedTask = ser.deserialize[Task[_]](
|
||||||
taskBytes, Thread.currentThread.getContextClassLoader)
|
taskBytes, Thread.currentThread.getContextClassLoader)
|
||||||
|
val deserTime = System.currentTimeMillis() - deserStart
|
||||||
|
|
||||||
// Run it
|
// Run it
|
||||||
val result: Any = deserializedTask.run(attemptId)
|
val result: Any = deserializedTask.run(attemptId)
|
||||||
|
@ -77,15 +79,19 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
|
||||||
// executor does. This is useful to catch serialization errors early
|
// executor does. This is useful to catch serialization errors early
|
||||||
// on in development (so when users move their local Spark programs
|
// on in development (so when users move their local Spark programs
|
||||||
// to the cluster, they don't get surprised by serialization errors).
|
// to the cluster, they don't get surprised by serialization errors).
|
||||||
val resultToReturn = ser.deserialize[Any](ser.serialize(result))
|
val serResult = ser.serialize(result)
|
||||||
|
deserializedTask.metrics.get.resultSize = serResult.limit()
|
||||||
|
val resultToReturn = ser.deserialize[Any](serResult)
|
||||||
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
|
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
|
||||||
ser.serialize(Accumulators.values))
|
ser.serialize(Accumulators.values))
|
||||||
logInfo("Finished " + task)
|
logInfo("Finished " + task)
|
||||||
info.markSuccessful()
|
info.markSuccessful()
|
||||||
|
deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough
|
||||||
|
deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
|
||||||
|
|
||||||
// If the threadpool has not already been shutdown, notify DAGScheduler
|
// If the threadpool has not already been shutdown, notify DAGScheduler
|
||||||
if (!Thread.currentThread().isInterrupted)
|
if (!Thread.currentThread().isInterrupted)
|
||||||
listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, null)
|
listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null))
|
||||||
} catch {
|
} catch {
|
||||||
case t: Throwable => {
|
case t: Throwable => {
|
||||||
logError("Exception in task " + idInJob, t)
|
logError("Exception in task " + idInJob, t)
|
||||||
|
|
80
core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
Normal file
80
core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
package spark.scheduler
|
||||||
|
|
||||||
|
import org.scalatest.FunSuite
|
||||||
|
import spark.{SparkContext, LocalSparkContext}
|
||||||
|
import scala.collection.mutable
|
||||||
|
import org.scalatest.matchers.ShouldMatchers
|
||||||
|
import spark.SparkContext._
|
||||||
|
|
||||||
|
/**
 * Runs jobs on a local-mode `SparkContext` with a recording `SparkListener` attached and
 * asserts on the `StageInfo` objects handed to `onStageCompleted`: how many stages were
 * reported, and that the per-task metrics (duration, executor run time, deserialize time,
 * result size, shuffle write and shuffle read metrics) carry the expected values.
 */
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {

  test("local metrics") {
    sc = new SparkContext("local[4]", "test")
    val listener = new SaveStageInfo
    sc.addSparkListener(listener)
    sc.addSparkListener(new StatsReportListener)

    // First job: a plain count, expected to report exactly one completed stage.
    val d = sc.parallelize(1 to 1e4.toInt, 64)
    d.count
    listener.stageInfos.size should be (1)

    // Second job: two named shuffle inputs feeding a cogroup.
    val d2 = d.map { i => i -> i * 2 }.setName("shuffle input 1")
    val d3 = d.map { i => i -> (0 to (i % 5)) }.setName("shuffle input 2")
    val d4 = d2.cogroup(d3, 64).map { case (k, (v1, v2)) => k -> (v1.size, v2.size) }
    d4.setName("A Cogroup")

    d4.collectAsMap

    // One stage from the count above plus three more from the cogroup job.
    listener.stageInfos.size should be (4)
    listener.stageInfos.foreach { stageInfo =>
      // True only for the stage whose final RDD is the cogroup result.
      val isCogroupStage = stageInfo.stage.rdd.name == d4.name

      // Small test, so individual tasks may take less than 1 millisecond; only the
      // per-stage average is required to be non-zero.
      checkNonZeroAvg(stageInfo.taskInfos.map { _._1.duration }, stageInfo + " duration")
      checkNonZeroAvg(
        stageInfo.taskInfos.map { _._2.executorRunTime.toLong },
        stageInfo + " executorRunTime")
      checkNonZeroAvg(
        stageInfo.taskInfos.map { _._2.executorDeserializeTime.toLong },
        stageInfo + " executorDeserializeTime")
      if (isCogroupStage) {
        checkNonZeroAvg(
          stageInfo.taskInfos.map { _._2.shuffleReadMetrics.get.fetchWaitTime },
          stageInfo + " fetchWaitTime")
      }

      // Depends only on the stage, so it is evaluated once rather than once per task.
      val isShuffleInputStage = isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))

      stageInfo.taskInfos.foreach { case (_, taskMetrics) =>
        taskMetrics.resultSize should be > (0L)
        if (isShuffleInputStage) {
          taskMetrics.shuffleWriteMetrics should be ('defined)
          taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0L)
        }
        if (isCogroupStage) {
          taskMetrics.shuffleReadMetrics should be ('defined)
          val sm = taskMetrics.shuffleReadMetrics.get
          sm.totalBlocksFetched should be > (0)
          sm.shuffleReadMillis should be > (0L)
          sm.localBlocksFetched should be > (0)
          // The test expects every block to be fetched locally, so all remote counters are zero.
          sm.remoteBlocksFetched should be (0)
          sm.remoteBytesRead should be (0L)
          sm.remoteFetchTime should be (0L)
        }
      }
    }
  }

  /**
   * Asserts that the arithmetic mean of `m` is strictly positive.
   *
   * An empty `m` yields `0 / 0.0` (NaN), which is not greater than zero, so the assertion
   * fails in that case as well.
   *
   * @param m   the values to average
   * @param msg clue attached to the assertion failure
   */
  def checkNonZeroAvg(m: Traversable[Long], msg: String): Unit = {
    assert(m.sum / m.size.toDouble > 0.0, msg)
  }

  /**
   * Tells whether a stage involves one of `rddNames` without involving any of `excludedNames`.
   *
   * The names considered are the name of the stage's own RDD plus the names of the RDDs of
   * that RDD's direct dependencies.
   *
   * @param stageInfo     the completed stage to inspect
   * @param rddNames      names of which at least one must be present
   * @param excludedNames names of which none may be present
   * @return true when some name from `rddNames` and no name from `excludedNames` is present
   */
  def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]): Boolean = {
    val stageRdd = stageInfo.stage.rdd
    val names = Set(stageRdd.name) ++ stageRdd.dependencies.map { _.rdd.name }
    names.intersect(rddNames).nonEmpty && names.intersect(excludedNames).isEmpty
  }

  /**
   * Listener that appends the `StageInfo` of every `StageCompleted` event to `stageInfos`.
   *
   * NOTE(review): the buffer is an unsynchronized mutable collection; confirm which thread
   * delivers listener events before relying on it from the test thread.
   */
  class SaveStageInfo extends SparkListener {
    val stageInfos = mutable.Buffer[StageInfo]()

    def onStageCompleted(stage: StageCompleted): Unit = {
      stageInfos += stage.stageInfo
    }
  }

}
|
Loading…
Reference in a new issue