diff --git a/sql/core/benchmarks/MetricsAggregationBenchmark-jdk11-results.txt b/sql/core/benchmarks/MetricsAggregationBenchmark-jdk11-results.txt
new file mode 100644
index 0000000000..e33ed30eaa
--- /dev/null
+++ b/sql/core/benchmarks/MetricsAggregationBenchmark-jdk11-results.txt
@@ -0,0 +1,12 @@
+OpenJDK 64-Bit Server VM 11.0.4+11 on Linux 4.15.0-66-generic
+Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
+metrics aggregation (50 metrics, 100000 tasks per stage):  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+1 stage(s)                                                          672            841         179          0.0   671888474.0       1.0X
+2 stage(s)                                                         1700           1842         201          0.0  1699591662.0       0.4X
+3 stage(s)                                                         2601           2776         247          0.0  2601465786.0       0.3X
+
+Stage Count    Stage Proc. Time    Aggreg. Time
+          1              436             164
+          2              537             354
+          3              480             602
diff --git a/sql/core/benchmarks/MetricsAggregationBenchmark-results.txt b/sql/core/benchmarks/MetricsAggregationBenchmark-results.txt
new file mode 100644
index 0000000000..4fae928258
--- /dev/null
+++ b/sql/core/benchmarks/MetricsAggregationBenchmark-results.txt
@@ -0,0 +1,12 @@
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic
+Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
+metrics aggregation (50 metrics, 100000 tasks per stage):  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+1 stage(s)                                                          740            883         147          0.0   740089816.0       1.0X
+2 stage(s)                                                         1661           1943         399          0.0  1660649192.0       0.4X
+3 stage(s)                                                         2711           2967         362          0.0  2711110178.0       0.3X
+
+Stage Count    Stage Proc. Time    Aggreg. Time
+          1              405             179
+          2              375             414
+          3              364             644
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 19809b0750..b7f0ab2969 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.metric
 
 import java.text.NumberFormat
-import java.util.Locale
+import java.util.{Arrays, Locale}
 
 import scala.concurrent.duration._
 
@@ -150,7 +150,7 @@ object SQLMetrics {
    * A function that defines how we aggregate the final accumulator results among all tasks,
    * and represent it in string for a SQL physical operator.
    */
-  def stringValue(metricsType: String, values: Seq[Long]): String = {
+  def stringValue(metricsType: String, values: Array[Long]): String = {
     if (metricsType == SUM_METRIC) {
       val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
       numberFormat.format(values.sum)
@@ -162,8 +162,9 @@ object SQLMetrics {
       val metric = if (validValues.isEmpty) {
         Seq.fill(3)(0L)
       } else {
-        val sorted = validValues.sorted
-        Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        Arrays.sort(validValues)
+        Seq(validValues(0), validValues(validValues.length / 2),
+          validValues(validValues.length - 1))
       }
       metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric))
     }
@@ -184,8 +185,9 @@ object SQLMetrics {
       val metric = if (validValues.isEmpty) {
         Seq.fill(4)(0L)
       } else {
-        val sorted = validValues.sorted
-        Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        Arrays.sort(validValues)
+        Seq(validValues.sum, validValues(0), validValues(validValues.length / 2),
+          validValues(validValues.length - 1))
       }
       metric.map(strFormat)
     }
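Note on the SQLMetrics change above: stringValue now receives a plain Array[Long] that it can sort in place with java.util.Arrays.sort and index directly, instead of materializing a sorted copy of a Seq. The following standalone Scala sketch (not part of the patch; the output label and names are illustrative, and the negative-value filtering done by the real method is omitted) shows the same sum/min/median/max selection on an in-place sorted array:

import java.text.NumberFormat
import java.util.{Arrays, Locale}

object StringValueSketch {
  // Sort once in place, then pick min, median and max by index, roughly as the
  // patched stringValue does for the size/timing style metrics.
  def summarize(values: Array[Long]): String = {
    val fmt = NumberFormat.getIntegerInstance(Locale.US)
    if (values.isEmpty) {
      "total (min, med, max): 0 (0, 0, 0)"
    } else {
      Arrays.sort(values)
      val picks = Seq(values.sum, values(0), values(values.length / 2), values(values.length - 1))
      val Seq(total, min, med, max) = picks.map(v => fmt.format(v))
      s"total (min, med, max): $total ($min, $med, $max)"
    }
  }

  def main(args: Array[String]): Unit = {
    // Prints: total (min, med, max): 16 (1, 3, 10); the median is the upper-middle
    // element, matching the length / 2 index used above.
    println(summarize(Array(3L, 1L, 2L, 10L)))
  }
}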
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index 2c4a7eacdf..da526612e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -16,10 +16,11 @@
  */
 package org.apache.spark.sql.execution.ui
 
-import java.util.{Date, NoSuchElementException}
+import java.util.{Arrays, Date, NoSuchElementException}
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import org.apache.spark.{JobExecutionStatus, SparkConf}
 import org.apache.spark.internal.Logging
@@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.metric._
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
+import org.apache.spark.util.collection.OpenHashMap
 
 class SQLAppStatusListener(
     conf: SparkConf,
@@ -103,8 +105,10 @@ class SQLAppStatusListener(
     // Record the accumulator IDs for the stages of this job, so that the code that keeps
     // track of the metrics knows which accumulators to look at.
     val accumIds = exec.metrics.map(_.accumulatorId).toSet
-    event.stageIds.foreach { id =>
-      stageMetrics.put(id, new LiveStageMetrics(id, 0, accumIds, new ConcurrentHashMap()))
+    if (accumIds.nonEmpty) {
+      event.stageInfos.foreach { stage =>
+        stageMetrics.put(stage.stageId, new LiveStageMetrics(0, stage.numTasks, accumIds))
+      }
     }
 
     exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING)
@@ -118,9 +122,11 @@ class SQLAppStatusListener(
     }
 
     // Reset the metrics tracking object for the new attempt.
-    Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics =>
-      metrics.taskMetrics.clear()
-      metrics.attemptId = event.stageInfo.attemptNumber
+    Option(stageMetrics.get(event.stageInfo.stageId)).foreach { stage =>
+      if (stage.attemptId != event.stageInfo.attemptNumber) {
+        stageMetrics.put(event.stageInfo.stageId,
+          new LiveStageMetrics(event.stageInfo.attemptNumber, stage.numTasks, stage.accumulatorIds))
+      }
     }
   }
 
@@ -140,7 +146,16 @@ class SQLAppStatusListener(
 
   override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = {
     event.accumUpdates.foreach { case (taskId, stageId, attemptId, accumUpdates) =>
-      updateStageMetrics(stageId, attemptId, taskId, accumUpdates, false)
+      updateStageMetrics(stageId, attemptId, taskId, SQLAppStatusListener.UNKNOWN_INDEX,
+        accumUpdates, false)
+    }
+  }
+
+  override def onTaskStart(event: SparkListenerTaskStart): Unit = {
+    Option(stageMetrics.get(event.stageId)).foreach { stage =>
+      if (stage.attemptId == event.stageAttemptId) {
+        stage.registerTask(event.taskInfo.taskId, event.taskInfo.index)
+      }
     }
   }
 
@@ -165,7 +180,7 @@ class SQLAppStatusListener(
     } else {
       info.accumulables
     }
-    updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, accums,
+    updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, info.index, accums,
       info.successful)
   }
 
@@ -181,17 +196,40 @@ class SQLAppStatusListener(
 
   private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
     val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
-    val metrics = exec.stages.toSeq
-      .flatMap { stageId => Option(stageMetrics.get(stageId)) }
-      .flatMap(_.taskMetrics.values().asScala)
-      .flatMap { metrics => metrics.ids.zip(metrics.values) }
-    val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq)
-      .filter { case (id, _) => metricTypes.contains(id) }
-      .groupBy(_._1)
-      .map { case (id, values) =>
-        id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2))
+    val taskMetrics = exec.stages.toSeq
+      .flatMap { stageId => Option(stageMetrics.get(stageId)) }
+      .flatMap(_.metricValues())
+
+    val allMetrics = new mutable.HashMap[Long, Array[Long]]()
+
+    taskMetrics.foreach { case (id, values) =>
+      val prev = allMetrics.getOrElse(id, null)
+      val updated = if (prev != null) {
+        prev ++ values
+      } else {
+        values
       }
+      allMetrics(id) = updated
+    }
+
+    exec.driverAccumUpdates.foreach { case (id, value) =>
+      if (metricTypes.contains(id)) {
+        val prev = allMetrics.getOrElse(id, null)
+        val updated = if (prev != null) {
+          val _copy = Arrays.copyOf(prev, prev.length + 1)
+          _copy(prev.length) = value
+          _copy
+        } else {
+          Array(value)
+        }
+        allMetrics(id) = updated
+      }
+    }
+
+    val aggregatedMetrics = allMetrics.map { case (id, values) =>
+      id -> SQLMetrics.stringValue(metricTypes(id), values)
+    }.toMap
 
     // Check the execution again for whether the aggregated metrics data has been calculated.
     // This can happen if the UI is requesting this data, and the onExecutionEnd handler is
@@ -208,43 +246,13 @@ class SQLAppStatusListener(
       stageId: Int,
       attemptId: Int,
       taskId: Long,
+      taskIdx: Int,
       accumUpdates: Seq[AccumulableInfo],
       succeeded: Boolean): Unit = {
     Option(stageMetrics.get(stageId)).foreach { metrics =>
-      if (metrics.attemptId != attemptId || metrics.accumulatorIds.isEmpty) {
-        return
+      if (metrics.attemptId == attemptId) {
+        metrics.updateTaskMetrics(taskId, taskIdx, succeeded, accumUpdates)
       }
-
-      val oldTaskMetrics = metrics.taskMetrics.get(taskId)
-      if (oldTaskMetrics != null && oldTaskMetrics.succeeded) {
-        return
-      }
-
-      val updates = accumUpdates
-        .filter { acc => acc.update.isDefined && metrics.accumulatorIds.contains(acc.id) }
-        .sortBy(_.id)
-
-      if (updates.isEmpty) {
-        return
-      }
-
-      val ids = new Array[Long](updates.size)
-      val values = new Array[Long](updates.size)
-      updates.zipWithIndex.foreach { case (acc, idx) =>
-        ids(idx) = acc.id
-        // In a live application, accumulators have Long values, but when reading from event
-        // logs, they have String values. For now, assume all accumulators are Long and covert
-        // accordingly.
-        values(idx) = acc.update.get match {
-          case s: String => s.toLong
-          case l: Long => l
-          case o => throw new IllegalArgumentException(s"Unexpected: $o")
-        }
-      }
-
-      // TODO: storing metrics by task ID can cause metrics for the same task index to be
-      // counted multiple times, for example due to speculation or re-attempts.
-      metrics.taskMetrics.put(taskId, new LiveTaskMetrics(ids, values, succeeded))
     }
   }
 
@@ -425,12 +433,76 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity {
 }
 
 private class LiveStageMetrics(
-    val stageId: Int,
-    var attemptId: Int,
-    val accumulatorIds: Set[Long],
-    val taskMetrics: ConcurrentHashMap[Long, LiveTaskMetrics])
+    val attemptId: Int,
+    val numTasks: Int,
+    val accumulatorIds: Set[Long]) {
 
-private class LiveTaskMetrics(
-    val ids: Array[Long],
-    val values: Array[Long],
-    val succeeded: Boolean)
+  /**
+   * Mapping of task IDs to their respective index. Note this may contain more elements than the
+   * stage's number of tasks, if speculative execution is on.
+   */
+  private val taskIndices = new OpenHashMap[Long, Int]()
+
+  /** Bit set tracking which indices have been successfully computed. */
+  private val completedIndices = new mutable.BitSet()
+
+  /**
+   * Task metrics values for the stage. Maps the metric ID to the metric values for each
+   * index. For each metric ID, there will be the same number of values as the number
+   * of indices. This relies on `SQLMetrics.stringValue` treating 0 as a neutral value,
+   * independent of the actual metric type.
+   */
+  private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()
+
+  def registerTask(taskId: Long, taskIdx: Int): Unit = {
+    taskIndices.update(taskId, taskIdx)
+  }
+
+  def updateTaskMetrics(
+      taskId: Long,
+      eventIdx: Int,
+      finished: Boolean,
+      accumUpdates: Seq[AccumulableInfo]): Unit = {
+    val taskIdx = if (eventIdx == SQLAppStatusListener.UNKNOWN_INDEX) {
+      if (!taskIndices.contains(taskId)) {
+        // We probably missed the start event for the task, just ignore it.
+        return
+      }
+      taskIndices(taskId)
+    } else {
+      // Here we can recover from a missing task start event. Just register the task again.
+      registerTask(taskId, eventIdx)
+      eventIdx
+    }
+
+    if (completedIndices.contains(taskIdx)) {
+      return
+    }
+
+    accumUpdates
+      .filter { acc => acc.update.isDefined && accumulatorIds.contains(acc.id) }
+      .foreach { acc =>
+        // In a live application, accumulators have Long values, but when reading from event
+        // logs, they have String values. For now, assume all accumulators are Long and convert
+        // accordingly.
+        val value = acc.update.get match {
+          case s: String => s.toLong
+          case l: Long => l
+          case o => throw new IllegalArgumentException(s"Unexpected: $o")
+        }
+
+        val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new Array(numTasks))
+        metricValues(taskIdx) = value
+      }
+
+    if (finished) {
+      completedIndices += taskIdx
+    }
+  }
+
+  def metricValues(): Seq[(Long, Array[Long])] = taskMetrics.asScala.toSeq
+}
+
+private object SQLAppStatusListener {
+  val UNKNOWN_INDEX = -1
+}
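Note on the LiveStageMetrics rewrite above: metric values are now stored per task index, and a bit set of completed indices ensures that a speculative or retried attempt can only overwrite its own slot rather than contribute a second value. A minimal standalone sketch of that bookkeeping (simplified types and illustrative names, not Spark's API):

import scala.collection.mutable

// Standalone sketch, not Spark code: one instance models one stage attempt.
class StageMetricsSketch(numTasks: Int) {
  private val taskIndices = mutable.Map[Long, Int]()           // task ID -> task index
  private val completedIndices = mutable.BitSet()              // indices already finalized
  private val metricValues = mutable.Map[Long, Array[Long]]()  // metric ID -> one value per index

  def registerTask(taskId: Long, taskIdx: Int): Unit = taskIndices(taskId) = taskIdx

  def update(taskId: Long, finished: Boolean, updates: Map[Long, Long]): Unit = {
    taskIndices.get(taskId) match {
      case None => // missed the task start event; ignore the update
      case Some(taskIdx) if completedIndices.contains(taskIdx) => // a finished attempt won already
      case Some(taskIdx) =>
        updates.foreach { case (metricId, value) =>
          val slots = metricValues.getOrElseUpdate(metricId, new Array[Long](numTasks))
          slots(taskIdx) = value                               // overwrite, never accumulate twice
        }
        if (finished) completedIndices += taskIdx
    }
  }

  def valuesFor(metricId: Long): Array[Long] =
    metricValues.getOrElse(metricId, new Array[Long](numTasks))
}

object StageMetricsSketchDemo {
  def main(args: Array[String]): Unit = {
    val stage = new StageMetricsSketch(numTasks = 2)
    stage.registerTask(taskId = 10L, taskIdx = 0)
    stage.registerTask(taskId = 11L, taskIdx = 1)
    stage.registerTask(taskId = 12L, taskIdx = 1)              // speculative copy of index 1
    stage.update(10L, finished = true, Map(7L -> 5L))
    stage.update(11L, finished = true, Map(7L -> 3L))
    stage.update(12L, finished = true, Map(7L -> 4L))          // ignored: index 1 already finished
    println(stage.valuesFor(7L).mkString(", "))                // prints: 5, 3
  }
}

Running the demo prints "5, 3": the duplicate update for index 1 is dropped because that index is already marked complete, which is the double-counting problem the TODO in the removed code pointed out.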
+ */ + +package org.apache.spark.sql.execution.ui + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable +import scala.concurrent.duration._ + +import org.apache.spark.{SparkConf, TaskState} +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.executor.ExecutorMetrics +import org.apache.spark.internal.config.Status._ +import org.apache.spark.scheduler._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.status.ElementTrackingStore +import org.apache.spark.util.{AccumulatorMetadata, LongAccumulator, Utils} +import org.apache.spark.util.kvstore.InMemoryStore + +/** + * Benchmark for metrics aggregation in the SQL listener. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class --jars + * 2. build/sbt "core/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/test:runMain " + * Results will be written to "benchmarks/MetricsAggregationBenchmark-results.txt". + * }}} + */ +object MetricsAggregationBenchmark extends BenchmarkBase { + + private def metricTrackingBenchmark( + timer: Benchmark.Timer, + numMetrics: Int, + numTasks: Int, + numStages: Int): Measurements = { + val conf = new SparkConf() + .set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + .set(ASYNC_TRACKING_ENABLED, false) + val kvstore = new ElementTrackingStore(new InMemoryStore(), conf) + val listener = new SQLAppStatusListener(conf, kvstore, live = true) + val store = new SQLAppStatusStore(kvstore, Some(listener)) + + val metrics = (0 until numMetrics).map { i => + new SQLMetricInfo(s"metric$i", i.toLong, "average") + } + + val planInfo = new SparkPlanInfo( + getClass().getName(), + getClass().getName(), + Nil, + Map.empty, + metrics) + + val idgen = new AtomicInteger() + val executionId = idgen.incrementAndGet() + val executionStart = SparkListenerSQLExecutionStart( + executionId, + getClass().getName(), + getClass().getName(), + getClass().getName(), + planInfo, + System.currentTimeMillis()) + + val executionEnd = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis()) + + val properties = new Properties() + properties.setProperty(SQLExecution.EXECUTION_ID_KEY, executionId.toString) + + timer.startTiming() + listener.onOtherEvent(executionStart) + + val taskEventsTime = (0 until numStages).map { _ => + val stageInfo = new StageInfo(idgen.incrementAndGet(), 0, getClass().getName(), + numTasks, Nil, Nil, getClass().getName()) + + val jobId = idgen.incrementAndGet() + val jobStart = SparkListenerJobStart( + jobId = jobId, + time = System.currentTimeMillis(), + stageInfos = Seq(stageInfo), + properties) + + val stageStart = SparkListenerStageSubmitted(stageInfo) + + val taskOffset = idgen.incrementAndGet().toLong + val taskEvents = (0 until numTasks).map { i => + val info = new TaskInfo( + taskId = taskOffset + i.toLong, + index = i, + attemptNumber = 0, + // The following fields are not used. 
+          launchTime = 0,
+          executorId = "",
+          host = "",
+          taskLocality = null,
+          speculative = false)
+        info.markFinished(TaskState.FINISHED, 1L)
+
+        val accumulables = (0 until numMetrics).map { mid =>
+          val acc = new LongAccumulator
+          acc.metadata = AccumulatorMetadata(mid, None, false)
+          acc.toInfo(Some(i.toLong), None)
+        }
+
+        info.setAccumulables(accumulables)
+
+        val start = SparkListenerTaskStart(stageInfo.stageId, stageInfo.attemptNumber, info)
+        val end = SparkListenerTaskEnd(stageInfo.stageId, stageInfo.attemptNumber,
+          taskType = "",
+          reason = null,
+          info,
+          new ExecutorMetrics(),
+          null)
+
+        (start, end)
+      }
+
+      val jobEnd = SparkListenerJobEnd(
+        jobId = jobId,
+        time = System.currentTimeMillis(),
+        JobSucceeded)
+
+      listener.onJobStart(jobStart)
+      listener.onStageSubmitted(stageStart)
+
+      val (_, _taskEventsTime) = Utils.timeTakenMs {
+        taskEvents.foreach { case (start, end) =>
+          listener.onTaskStart(start)
+          listener.onTaskEnd(end)
+        }
+      }
+
+      listener.onJobEnd(jobEnd)
+      _taskEventsTime
+    }
+
+    val (_, aggTime) = Utils.timeTakenMs {
+      listener.onOtherEvent(executionEnd)
+      val metrics = store.executionMetrics(executionId)
+      assert(metrics.size == numMetrics, s"${metrics.size} != $numMetrics")
+    }
+
+    timer.stopTiming()
+    kvstore.close()
+
+    Measurements(taskEventsTime, aggTime)
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    val metricCount = 50
+    val taskCount = 100000
+    val stageCounts = Seq(1, 2, 3)
+
+    val benchmark = new Benchmark(
+      s"metrics aggregation ($metricCount metrics, $taskCount tasks per stage)", 1,
+      warmupTime = 0.seconds, output = output)
+
+    // Run this outside the measurement code so that classes are loaded and JIT is triggered;
+    // otherwise the first run tends to be much slower than the others. Also, this benchmark is a
+    // bit unusual and doesn't really map to what the Benchmark class expects, so it's harder
+    // to use warmupTime and friends effectively.
+    stageCounts.foreach { count =>
+      metricTrackingBenchmark(new Benchmark.Timer(-1), metricCount, taskCount, count)
+    }
+
+    val measurements = mutable.HashMap[Int, Seq[Measurements]]()
+
+    stageCounts.foreach { count =>
+      benchmark.addTimerCase(s"$count stage(s)") { timer =>
+        val m = metricTrackingBenchmark(timer, metricCount, taskCount, count)
+        val all = measurements.getOrElse(count, Nil)
+        measurements(count) = all ++ Seq(m)
+      }
+    }
+
+    benchmark.run()
+
+    benchmark.out.printf("Stage Count    Stage Proc. Time    Aggreg. Time\n")
+    stageCounts.foreach { count =>
+      val data = measurements(count)
+      val eventsTimes = data.flatMap(_.taskEventsTimes)
+      val aggTimes = data.map(_.aggregationTime)
+
+      val msg = "          %d              %d             %d\n".format(
+        count,
+        eventsTimes.sum / eventsTimes.size,
+        aggTimes.sum / aggTimes.size)
+      benchmark.out.printf(msg)
+    }
+  }
+
+  /**
+   * Finer-grained measurements of how long it takes to run some parts of the benchmark. This is
+   * collected by the benchmark method, so this collection slightly affects the overall benchmark
+   * results, but this data helps with seeing where the time is going, since this benchmark is
+   * triggering a whole lot of code in the listener class.
+   */
+  case class Measurements(
+      taskEventsTimes: Seq[Long],
+      aggregationTime: Long)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index 88864ccec7..b8c0935b33 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -79,9 +79,9 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
   private def createStageInfo(stageId: Int, attemptId: Int): StageInfo = {
     new StageInfo(stageId = stageId,
       attemptId = attemptId,
+      numTasks = 8,
       // The following fields are not used in tests
       name = "",
-      numTasks = 0,
       rddInfos = Nil,
       parentIds = Nil,
       details = "")
@@ -94,8 +94,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
     val info = new TaskInfo(
       taskId = taskId,
       attemptNumber = attemptNumber,
+      index = taskId.toInt,
       // The following fields are not used in tests
-      index = 0,
       launchTime = 0,
       executorId = "",
       host = "",
@@ -190,6 +190,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
       ),
       createProperties(executionId)))
     listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 0)))
+    listener.onTaskStart(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0)))
+    listener.onTaskStart(SparkListenerTaskStart(0, 0, createTaskInfo(1, 0)))
 
     assert(statusStore.executionMetrics(executionId).isEmpty)
 
@@ -217,6 +219,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
 
     // Retrying a stage should reset the metrics
     listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1)))
+    listener.onTaskStart(SparkListenerTaskStart(0, 1, createTaskInfo(0, 0)))
+    listener.onTaskStart(SparkListenerTaskStart(0, 1, createTaskInfo(1, 0)))
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
@@ -260,6 +264,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
 
     // Summit a new stage
     listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0)))
+    listener.onTaskStart(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0)))
+    listener.onTaskStart(SparkListenerTaskStart(1, 0, createTaskInfo(1, 0)))
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
@@ -490,8 +496,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
     val statusStore = spark.sharedState.statusStore
     val oldCount = statusStore.executionsList().size
 
-    val expectedAccumValue = 12345
-    val expectedAccumValue2 = 54321
+    val expectedAccumValue = 12345L
+    val expectedAccumValue2 = 54321L
     val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue, expectedAccumValue2)
     val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) {
       override lazy val sparkPlan = physicalPlan
@@ -517,8 +523,9 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
     val metrics = statusStore.executionMetrics(execId)
     val driverMetric = physicalPlan.metrics("dummy")
     val driverMetric2 = physicalPlan.metrics("dummy2")
-    val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue))
-    val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType, Seq(expectedAccumValue2))
+    val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Array(expectedAccumValue))
+    val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType,
+      Array(expectedAccumValue2))
 
     assert(metrics.contains(driverMetric.id))
     assert(metrics(driverMetric.id) === expectedValue)
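As a closing illustration of the aggregation path the listener changes above implement: per-stage arrays of task metric values are concatenated per metric ID, and driver-side accumulator updates are appended as single extra slots before the combined array is handed to a formatter such as stringValue. A standalone sketch under those assumptions (illustrative names, not Spark's API):

import java.util.Arrays
import scala.collection.mutable

// Standalone sketch, not Spark code.
object AggregateSketch {
  // Merge per-task metric arrays (one slot per task index) with single driver-side values,
  // producing one flat Array[Long] per metric ID.
  def merge(
      taskMetrics: Seq[(Long, Array[Long])],
      driverUpdates: Seq[(Long, Long)]): Map[Long, Array[Long]] = {
    val all = new mutable.HashMap[Long, Array[Long]]()
    taskMetrics.foreach { case (id, values) =>
      all(id) = all.get(id).map(_ ++ values).getOrElse(values)
    }
    driverUpdates.foreach { case (id, value) =>
      val updated = all.get(id) match {
        case Some(prev) =>
          val copy = Arrays.copyOf(prev, prev.length + 1)  // append without an intermediate builder
          copy(prev.length) = value
          copy
        case None => Array(value)
      }
      all(id) = updated
    }
    all.toMap
  }

  def main(args: Array[String]): Unit = {
    val merged = merge(Seq(1L -> Array(10L, 20L)), Seq(1L -> 5L))
    println(merged(1L).mkString(", "))  // prints: 10, 20, 5
  }
}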