[SPARK-29562][SQL] Speed up and slim down metric aggregation in SQL listener

First, a bit of background on the code being changed. The current code tracks
metric updates for each task, recording which metrics the task is monitoring
and the last update value.

Once a SQL execution finishes, then the metrics for all the stages are
aggregated, by building a list with all (metric ID, value) pairs collected
for all tasks in the stages related to the execution, then grouping by metric
ID, and then calculating the values shown in the UI.

That is full of inefficiencies:

- in normal operation, all tasks will be tracking and updating the same
  metrics. So recording the metric IDs per task is wasteful.
- tracking by task means we might be double-counting values if you have
  speculative tasks (as a comment in the code mentions).
- creating a list of (metric ID, value) is extremely inefficient, because now
  you have a huge map in memory storing boxed versions of the metric IDs and
  values.
- same thing for the aggregation part, where now a Seq is built with the values
  for each metric ID.

The end result is that for large queries, this code can become both really
slow, thus affecting the processing of events, and memory hungry.

The updated code changes the approach to the following:

- stages track metrics by their ID; this means the stage tracking code
  naturally groups values, making aggregation later simpler.
- each metric ID being tracked uses a long array matching the number of
  partitions of the stage; this means that it's cheap to update the value of
  the metric once a task ends.
- when aggregating, custom code just concatenates the arrays corresponding to
  the matching metric IDs; this is cheaper than the previous, boxing-heavy
  approach.

The end result is that the listener uses about half as much memory as before
for tracking metrics, since it doesn't need to track metric IDs per task.

I captured heap dumps with the old and the new code during metric aggregation
in the listener, for an execution with 3 stages, 100k tasks per stage, 50
metrics updated per task. The dumps contained just reachable memory - so data
kept by the listener plus the variables in the aggregateMetrics() method.

With the old code, the thread doing aggregation references >1G of memory - and
that does not include temporary data created by the "groupBy" transformation
(for which the intermediate state is not referenced in the aggregation method).
The same thread with the new code references ~250M of memory. The old code uses
about ~250M to track all the metric values for that execution, while the new
code uses about ~130M. (Note the per-thread numbers include the amount used to
track the metrics - so, e.g., in the old case, aggregation was referencing
about ~750M of temporary data.)

I'm also including a small benchmark (based on the Benchmark class) so that we
can measure how much changes to this code affect performance. The benchmark
contains some extra code to measure things the normal Benchmark class does not,
given that the code under test does not really map that well to the
expectations of that class.

Running with the old code (I removed results that don't make much
sense for this benchmark):

```
[info] Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic
[info] Intel(R) Core(TM) i7-6820HQ CPU  2.70GHz
[info] metrics aggregation (50 metrics, 100k tasks per stage):  Best Time(ms)   Avg Time(ms)
[info] --------------------------------------------------------------------------------------
[info] 1 stage(s)                                                  2113           2118
[info] 2 stage(s)                                                  4172           4392
[info] 3 stage(s)                                                  7755           8460
[info]
[info] Stage Count    Stage Proc. Time    Aggreg. Time
[info]      1              614                1187
[info]      2              620                2480
[info]      3              718                5069
```

With the new code:

```
[info] Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic
[info] Intel(R) Core(TM) i7-6820HQ CPU  2.70GHz
[info] metrics aggregation (50 metrics, 100k tasks per stage):  Best Time(ms)   Avg Time(ms)
[info] --------------------------------------------------------------------------------------
[info] 1 stage(s)                                                   727            886
[info] 2 stage(s)                                                  1722           1983
[info] 3 stage(s)                                                  2752           3013
[info]
[info] Stage Count    Stage Proc. Time    Aggreg. Time
[info]      1              408                177
[info]      2              389                423
[info]      3              372                660

```

So the new code is faster than the old when processing task events, and about
an order of maginute faster when aggregating metrics.

Note this still leaves room for improvement; for example, using the above
measurements, 600ms is still a huge amount of time to spend in an event
handler. But I'll leave further enhancements for a separate change.

Tested with benchmarking code + existing unit tests.

Closes #26218 from vanzin/SPARK-29562.

Authored-by: Marcelo Vanzin <vanzin@cloudera.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
Marcelo Vanzin 2019-10-24 22:18:10 -07:00 committed by Dongjoon Hyun
parent 7417c3e7d5
commit 1474ed05fb
7 changed files with 396 additions and 71 deletions

View file

@ -0,0 +1,12 @@
OpenJDK 64-Bit Server VM 11.0.4+11 on Linux 4.15.0-66-generic
Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
metrics aggregation (50 metrics, 100000 tasks per stage): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
1 stage(s) 672 841 179 0.0 671888474.0 1.0X
2 stage(s) 1700 1842 201 0.0 1699591662.0 0.4X
3 stage(s) 2601 2776 247 0.0 2601465786.0 0.3X
Stage Count Stage Proc. Time Aggreg. Time
1 436 164
2 537 354
3 480 602

View file

@ -0,0 +1,12 @@
Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic
Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
metrics aggregation (50 metrics, 100000 tasks per stage): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
1 stage(s) 740 883 147 0.0 740089816.0 1.0X
2 stage(s) 1661 1943 399 0.0 1660649192.0 0.4X
3 stage(s) 2711 2967 362 0.0 2711110178.0 0.3X
Stage Count Stage Proc. Time Aggreg. Time
1 405 179
2 375 414
3 364 644

View file

@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.metric
import java.text.NumberFormat
import java.util.Locale
import java.util.{Arrays, Locale}
import scala.concurrent.duration._
@ -150,7 +150,7 @@ object SQLMetrics {
* A function that defines how we aggregate the final accumulator results among all tasks,
* and represent it in string for a SQL physical operator.
*/
def stringValue(metricsType: String, values: Seq[Long]): String = {
def stringValue(metricsType: String, values: Array[Long]): String = {
if (metricsType == SUM_METRIC) {
val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
numberFormat.format(values.sum)
@ -162,8 +162,9 @@ object SQLMetrics {
val metric = if (validValues.isEmpty) {
Seq.fill(3)(0L)
} else {
val sorted = validValues.sorted
Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
Arrays.sort(validValues)
Seq(validValues(0), validValues(validValues.length / 2),
validValues(validValues.length - 1))
}
metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric))
}
@ -184,8 +185,9 @@ object SQLMetrics {
val metric = if (validValues.isEmpty) {
Seq.fill(4)(0L)
} else {
val sorted = validValues.sorted
Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
Arrays.sort(validValues)
Seq(validValues.sum, validValues(0), validValues(validValues.length / 2),
validValues(validValues.length - 1))
}
metric.map(strFormat)
}

View file

@ -16,10 +16,11 @@
*/
package org.apache.spark.sql.execution.ui
import java.util.{Date, NoSuchElementException}
import java.util.{Arrays, Date, NoSuchElementException}
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
import scala.collection.mutable
import org.apache.spark.{JobExecutionStatus, SparkConf}
import org.apache.spark.internal.Logging
@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
import org.apache.spark.util.collection.OpenHashMap
class SQLAppStatusListener(
conf: SparkConf,
@ -103,8 +105,10 @@ class SQLAppStatusListener(
// Record the accumulator IDs for the stages of this job, so that the code that keeps
// track of the metrics knows which accumulators to look at.
val accumIds = exec.metrics.map(_.accumulatorId).toSet
event.stageIds.foreach { id =>
stageMetrics.put(id, new LiveStageMetrics(id, 0, accumIds, new ConcurrentHashMap()))
if (accumIds.nonEmpty) {
event.stageInfos.foreach { stage =>
stageMetrics.put(stage.stageId, new LiveStageMetrics(0, stage.numTasks, accumIds))
}
}
exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING)
@ -118,9 +122,11 @@ class SQLAppStatusListener(
}
// Reset the metrics tracking object for the new attempt.
Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics =>
metrics.taskMetrics.clear()
metrics.attemptId = event.stageInfo.attemptNumber
Option(stageMetrics.get(event.stageInfo.stageId)).foreach { stage =>
if (stage.attemptId != event.stageInfo.attemptNumber) {
stageMetrics.put(event.stageInfo.stageId,
new LiveStageMetrics(event.stageInfo.attemptNumber, stage.numTasks, stage.accumulatorIds))
}
}
}
@ -140,7 +146,16 @@ class SQLAppStatusListener(
override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = {
event.accumUpdates.foreach { case (taskId, stageId, attemptId, accumUpdates) =>
updateStageMetrics(stageId, attemptId, taskId, accumUpdates, false)
updateStageMetrics(stageId, attemptId, taskId, SQLAppStatusListener.UNKNOWN_INDEX,
accumUpdates, false)
}
}
override def onTaskStart(event: SparkListenerTaskStart): Unit = {
Option(stageMetrics.get(event.stageId)).foreach { stage =>
if (stage.attemptId == event.stageAttemptId) {
stage.registerTask(event.taskInfo.taskId, event.taskInfo.index)
}
}
}
@ -165,7 +180,7 @@ class SQLAppStatusListener(
} else {
info.accumulables
}
updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, accums,
updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, info.index, accums,
info.successful)
}
@ -181,17 +196,40 @@ class SQLAppStatusListener(
private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
val metrics = exec.stages.toSeq
.flatMap { stageId => Option(stageMetrics.get(stageId)) }
.flatMap(_.taskMetrics.values().asScala)
.flatMap { metrics => metrics.ids.zip(metrics.values) }
val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq)
.filter { case (id, _) => metricTypes.contains(id) }
.groupBy(_._1)
.map { case (id, values) =>
id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2))
val taskMetrics = exec.stages.toSeq
.flatMap { stageId => Option(stageMetrics.get(stageId)) }
.flatMap(_.metricValues())
val allMetrics = new mutable.HashMap[Long, Array[Long]]()
taskMetrics.foreach { case (id, values) =>
val prev = allMetrics.getOrElse(id, null)
val updated = if (prev != null) {
prev ++ values
} else {
values
}
allMetrics(id) = updated
}
exec.driverAccumUpdates.foreach { case (id, value) =>
if (metricTypes.contains(id)) {
val prev = allMetrics.getOrElse(id, null)
val updated = if (prev != null) {
val _copy = Arrays.copyOf(prev, prev.length + 1)
_copy(prev.length) = value
_copy
} else {
Array(value)
}
allMetrics(id) = updated
}
}
val aggregatedMetrics = allMetrics.map { case (id, values) =>
id -> SQLMetrics.stringValue(metricTypes(id), values)
}.toMap
// Check the execution again for whether the aggregated metrics data has been calculated.
// This can happen if the UI is requesting this data, and the onExecutionEnd handler is
@ -208,43 +246,13 @@ class SQLAppStatusListener(
stageId: Int,
attemptId: Int,
taskId: Long,
taskIdx: Int,
accumUpdates: Seq[AccumulableInfo],
succeeded: Boolean): Unit = {
Option(stageMetrics.get(stageId)).foreach { metrics =>
if (metrics.attemptId != attemptId || metrics.accumulatorIds.isEmpty) {
return
if (metrics.attemptId == attemptId) {
metrics.updateTaskMetrics(taskId, taskIdx, succeeded, accumUpdates)
}
val oldTaskMetrics = metrics.taskMetrics.get(taskId)
if (oldTaskMetrics != null && oldTaskMetrics.succeeded) {
return
}
val updates = accumUpdates
.filter { acc => acc.update.isDefined && metrics.accumulatorIds.contains(acc.id) }
.sortBy(_.id)
if (updates.isEmpty) {
return
}
val ids = new Array[Long](updates.size)
val values = new Array[Long](updates.size)
updates.zipWithIndex.foreach { case (acc, idx) =>
ids(idx) = acc.id
// In a live application, accumulators have Long values, but when reading from event
// logs, they have String values. For now, assume all accumulators are Long and covert
// accordingly.
values(idx) = acc.update.get match {
case s: String => s.toLong
case l: Long => l
case o => throw new IllegalArgumentException(s"Unexpected: $o")
}
}
// TODO: storing metrics by task ID can cause metrics for the same task index to be
// counted multiple times, for example due to speculation or re-attempts.
metrics.taskMetrics.put(taskId, new LiveTaskMetrics(ids, values, succeeded))
}
}
@ -425,12 +433,76 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity {
}
private class LiveStageMetrics(
val stageId: Int,
var attemptId: Int,
val accumulatorIds: Set[Long],
val taskMetrics: ConcurrentHashMap[Long, LiveTaskMetrics])
val attemptId: Int,
val numTasks: Int,
val accumulatorIds: Set[Long]) {
private class LiveTaskMetrics(
val ids: Array[Long],
val values: Array[Long],
val succeeded: Boolean)
/**
* Mapping of task IDs to their respective index. Note this may contain more elements than the
* stage's number of tasks, if speculative execution is on.
*/
private val taskIndices = new OpenHashMap[Long, Int]()
/** Bit set tracking which indices have been successfully computed. */
private val completedIndices = new mutable.BitSet()
/**
* Task metrics values for the stage. Maps the metric ID to the metric values for each
* index. For each metric ID, there will be the same number of values as the number
* of indices. This relies on `SQLMetrics.stringValue` treating 0 as a neutral value,
* independent of the actual metric type.
*/
private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()
def registerTask(taskId: Long, taskIdx: Int): Unit = {
taskIndices.update(taskId, taskIdx)
}
def updateTaskMetrics(
taskId: Long,
eventIdx: Int,
finished: Boolean,
accumUpdates: Seq[AccumulableInfo]): Unit = {
val taskIdx = if (eventIdx == SQLAppStatusListener.UNKNOWN_INDEX) {
if (!taskIndices.contains(taskId)) {
// We probably missed the start event for the task, just ignore it.
return
}
taskIndices(taskId)
} else {
// Here we can recover from a missing task start event. Just register the task again.
registerTask(taskId, eventIdx)
eventIdx
}
if (completedIndices.contains(taskIdx)) {
return
}
accumUpdates
.filter { acc => acc.update.isDefined && accumulatorIds.contains(acc.id) }
.foreach { acc =>
// In a live application, accumulators have Long values, but when reading from event
// logs, they have String values. For now, assume all accumulators are Long and convert
// accordingly.
val value = acc.update.get match {
case s: String => s.toLong
case l: Long => l
case o => throw new IllegalArgumentException(s"Unexpected: $o")
}
val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new Array(numTasks))
metricValues(taskIdx) = value
}
if (finished) {
completedIndices += taskIdx
}
}
def metricValues(): Seq[(Long, Array[Long])] = taskMetrics.asScala.toSeq
}
private object SQLAppStatusListener {
val UNKNOWN_INDEX = -1
}

View file

@ -232,7 +232,8 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId)
assert(expectedNodeName === actualNodeName)
for ((metricName, metricPredicate) <- expectedMetricsPredicatesMap) {
assert(metricPredicate(actualMetricsMap(metricName)))
assert(metricPredicate(actualMetricsMap(metricName)),
s"$nodeId / '$metricName' (= ${actualMetricsMap(metricName)}) did not match predicate.")
}
}
}

View file

@ -0,0 +1,219 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.ui
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable
import scala.concurrent.duration._
import org.apache.spark.{SparkConf, TaskState}
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.executor.ExecutorMetrics
import org.apache.spark.internal.config.Status._
import org.apache.spark.scheduler._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetricInfo
import org.apache.spark.status.ElementTrackingStore
import org.apache.spark.util.{AccumulatorMetadata, LongAccumulator, Utils}
import org.apache.spark.util.kvstore.InMemoryStore
/**
* Benchmark for metrics aggregation in the SQL listener.
* {{{
* To run this benchmark:
* 1. without sbt: bin/spark-submit --class <this class> --jars <core test jar>
* 2. build/sbt "core/test:runMain <this class>"
* 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/test:runMain <this class>"
* Results will be written to "benchmarks/MetricsAggregationBenchmark-results.txt".
* }}}
*/
object MetricsAggregationBenchmark extends BenchmarkBase {
private def metricTrackingBenchmark(
timer: Benchmark.Timer,
numMetrics: Int,
numTasks: Int,
numStages: Int): Measurements = {
val conf = new SparkConf()
.set(LIVE_ENTITY_UPDATE_PERIOD, 0L)
.set(ASYNC_TRACKING_ENABLED, false)
val kvstore = new ElementTrackingStore(new InMemoryStore(), conf)
val listener = new SQLAppStatusListener(conf, kvstore, live = true)
val store = new SQLAppStatusStore(kvstore, Some(listener))
val metrics = (0 until numMetrics).map { i =>
new SQLMetricInfo(s"metric$i", i.toLong, "average")
}
val planInfo = new SparkPlanInfo(
getClass().getName(),
getClass().getName(),
Nil,
Map.empty,
metrics)
val idgen = new AtomicInteger()
val executionId = idgen.incrementAndGet()
val executionStart = SparkListenerSQLExecutionStart(
executionId,
getClass().getName(),
getClass().getName(),
getClass().getName(),
planInfo,
System.currentTimeMillis())
val executionEnd = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis())
val properties = new Properties()
properties.setProperty(SQLExecution.EXECUTION_ID_KEY, executionId.toString)
timer.startTiming()
listener.onOtherEvent(executionStart)
val taskEventsTime = (0 until numStages).map { _ =>
val stageInfo = new StageInfo(idgen.incrementAndGet(), 0, getClass().getName(),
numTasks, Nil, Nil, getClass().getName())
val jobId = idgen.incrementAndGet()
val jobStart = SparkListenerJobStart(
jobId = jobId,
time = System.currentTimeMillis(),
stageInfos = Seq(stageInfo),
properties)
val stageStart = SparkListenerStageSubmitted(stageInfo)
val taskOffset = idgen.incrementAndGet().toLong
val taskEvents = (0 until numTasks).map { i =>
val info = new TaskInfo(
taskId = taskOffset + i.toLong,
index = i,
attemptNumber = 0,
// The following fields are not used.
launchTime = 0,
executorId = "",
host = "",
taskLocality = null,
speculative = false)
info.markFinished(TaskState.FINISHED, 1L)
val accumulables = (0 until numMetrics).map { mid =>
val acc = new LongAccumulator
acc.metadata = AccumulatorMetadata(mid, None, false)
acc.toInfo(Some(i.toLong), None)
}
info.setAccumulables(accumulables)
val start = SparkListenerTaskStart(stageInfo.stageId, stageInfo.attemptNumber, info)
val end = SparkListenerTaskEnd(stageInfo.stageId, stageInfo.attemptNumber,
taskType = "",
reason = null,
info,
new ExecutorMetrics(),
null)
(start, end)
}
val jobEnd = SparkListenerJobEnd(
jobId = jobId,
time = System.currentTimeMillis(),
JobSucceeded)
listener.onJobStart(jobStart)
listener.onStageSubmitted(stageStart)
val (_, _taskEventsTime) = Utils.timeTakenMs {
taskEvents.foreach { case (start, end) =>
listener.onTaskStart(start)
listener.onTaskEnd(end)
}
}
listener.onJobEnd(jobEnd)
_taskEventsTime
}
val (_, aggTime) = Utils.timeTakenMs {
listener.onOtherEvent(executionEnd)
val metrics = store.executionMetrics(executionId)
assert(metrics.size == numMetrics, s"${metrics.size} != $numMetrics")
}
timer.stopTiming()
kvstore.close()
Measurements(taskEventsTime, aggTime)
}
override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
val metricCount = 50
val taskCount = 100000
val stageCounts = Seq(1, 2, 3)
val benchmark = new Benchmark(
s"metrics aggregation ($metricCount metrics, $taskCount tasks per stage)", 1,
warmupTime = 0.seconds, output = output)
// Run this outside the measurement code so that classes are loaded and JIT is triggered,
// otherwise the first run tends to be much slower than others. Also because this benchmark is a
// bit weird and doesn't really map to what the Benchmark class expects, so it's a bit harder
// to use warmupTime and friends effectively.
stageCounts.foreach { count =>
metricTrackingBenchmark(new Benchmark.Timer(-1), metricCount, taskCount, count)
}
val measurements = mutable.HashMap[Int, Seq[Measurements]]()
stageCounts.foreach { count =>
benchmark.addTimerCase(s"$count stage(s)") { timer =>
val m = metricTrackingBenchmark(timer, metricCount, taskCount, count)
val all = measurements.getOrElse(count, Nil)
measurements(count) = all ++ Seq(m)
}
}
benchmark.run()
benchmark.out.printf("Stage Count Stage Proc. Time Aggreg. Time\n")
stageCounts.foreach { count =>
val data = measurements(count)
val eventsTimes = data.flatMap(_.taskEventsTimes)
val aggTimes = data.map(_.aggregationTime)
val msg = " %d %d %d\n".format(
count,
eventsTimes.sum / eventsTimes.size,
aggTimes.sum / aggTimes.size)
benchmark.out.printf(msg)
}
}
/**
* Finer-grained measurements of how long it takes to run some parts of the benchmark. This is
* collected by the benchmark method, so this collection slightly affects the overall benchmark
* results, but this data helps with seeing where the time is going, since this benchmark is
* triggering a whole lot of code in the listener class.
*/
case class Measurements(
taskEventsTimes: Seq[Long],
aggregationTime: Long)
}

View file

@ -79,9 +79,9 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
private def createStageInfo(stageId: Int, attemptId: Int): StageInfo = {
new StageInfo(stageId = stageId,
attemptId = attemptId,
numTasks = 8,
// The following fields are not used in tests
name = "",
numTasks = 0,
rddInfos = Nil,
parentIds = Nil,
details = "")
@ -94,8 +94,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
val info = new TaskInfo(
taskId = taskId,
attemptNumber = attemptNumber,
index = taskId.toInt,
// The following fields are not used in tests
index = 0,
launchTime = 0,
executorId = "",
host = "",
@ -190,6 +190,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
),
createProperties(executionId)))
listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 0)))
listener.onTaskStart(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0)))
listener.onTaskStart(SparkListenerTaskStart(0, 0, createTaskInfo(1, 0)))
assert(statusStore.executionMetrics(executionId).isEmpty)
@ -217,6 +219,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
// Retrying a stage should reset the metrics
listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1)))
listener.onTaskStart(SparkListenerTaskStart(0, 1, createTaskInfo(0, 0)))
listener.onTaskStart(SparkListenerTaskStart(0, 1, createTaskInfo(1, 0)))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
@ -260,6 +264,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
// Summit a new stage
listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0)))
listener.onTaskStart(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0)))
listener.onTaskStart(SparkListenerTaskStart(1, 0, createTaskInfo(1, 0)))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
@ -490,8 +496,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
val statusStore = spark.sharedState.statusStore
val oldCount = statusStore.executionsList().size
val expectedAccumValue = 12345
val expectedAccumValue2 = 54321
val expectedAccumValue = 12345L
val expectedAccumValue2 = 54321L
val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue, expectedAccumValue2)
val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) {
override lazy val sparkPlan = physicalPlan
@ -517,8 +523,9 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
val metrics = statusStore.executionMetrics(execId)
val driverMetric = physicalPlan.metrics("dummy")
val driverMetric2 = physicalPlan.metrics("dummy2")
val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue))
val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType, Seq(expectedAccumValue2))
val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Array(expectedAccumValue))
val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType,
Array(expectedAccumValue2))
assert(metrics.contains(driverMetric.id))
assert(metrics(driverMetric.id) === expectedValue)