From 44b695fbb06b0d89783b4838941c68543c5a5c8b Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Fri, 11 Jun 2021 01:20:35 +0800
Subject: [PATCH] [SPARK-35296][SQL] Allow Dataset.observe to work even if
 CollectMetricsExec in a task handles multiple partitions

### What changes were proposed in this pull request?

This PR fixes an issue where `Dataset.observe` doesn't work if `CollectMetricsExec` in a task handles multiple partitions. If `coalesce` follows `observe` and the number of partitions shrinks after `coalesce`, `CollectMetricsExec` can handle multiple partitions in a task.

### Why are the changes needed?

The current implementation of `CollectMetricsExec` doesn't consider the case where it handles multiple partitions. Because a new `updater` is created for each partition even though those partitions belong to the same task, `collector.setState(updater)` raises an assertion error.

Here is a simple reproducible example:
```
$ bin/spark-shell --master "local[1]"
scala> spark.range(1, 4, 1, 3).observe("my_event", count($"id").as("count_val")).coalesce(2).collect
```
```
java.lang.AssertionError: assertion failed
  at scala.Predef$.assert(Predef.scala:208)
  at org.apache.spark.sql.execution.AggregatingAccumulator.setState(AggregatingAccumulator.scala:204)
  at org.apache.spark.sql.execution.CollectMetricsExec.$anonfun$doExecute$2(CollectMetricsExec.scala:72)
  at org.apache.spark.sql.execution.CollectMetricsExec.$anonfun$doExecute$2$adapted(CollectMetricsExec.scala:71)
  at org.apache.spark.TaskContext$$anon$1.onTaskCompletion(TaskContext.scala:125)
  at org.apache.spark.TaskContextImpl.$anonfun$markTaskCompleted$1(TaskContextImpl.scala:124)
  at org.apache.spark.TaskContextImpl.$anonfun$markTaskCompleted$1$adapted(TaskContextImpl.scala:124)
  at org.apache.spark.TaskContextImpl.$anonfun$invokeListeners$1(TaskContextImpl.scala:137)
  at org.apache.spark.TaskContextImpl.$anonfun$invokeListeners$1$adapted(TaskContextImpl.scala:135)
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test.

Closes #32786 from sarutak/fix-collectmetricsexec.

Authored-by: Kousuke Saruta
Signed-off-by: Wenchen Fan
---
 .../execution/AggregatingAccumulator.scala | 16 +++++----
 .../sql/execution/CollectMetricsExec.scala |  6 +++-
 .../sql/util/DataFrameCallbackSuite.scala  | 34 +++++++++++++++++++
 3 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
index 94e159c562..0fa4e6c316 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
@@ -33,7 +33,7 @@ class AggregatingAccumulator private(
     bufferSchema: Seq[DataType],
     initialValues: Seq[Expression],
     updateExpressions: Seq[Expression],
-    @transient private val mergeExpressions: Seq[Expression],
+    mergeExpressions: Seq[Expression],
     @transient private val resultExpressions: Seq[Expression],
     imperatives: Array[ImperativeAggregate],
     typedImperatives: Array[TypedImperativeAggregate[_]],
@@ -95,13 +95,14 @@ class AggregatingAccumulator private(
   /**
    * Driver side operations like `merge` and `value` are executed in the DAGScheduler thread. This
-   * thread does not have a SQL configuration so we attach our own here. Note that we can't (and
-   * shouldn't) call `merge` or `value` on an accumulator originating from an executor so we just
-   * return a default value here.
+   * thread does not have a SQL configuration so we attach our own here.
    */
-  private[this] def withSQLConf[T](default: => T)(body: => T): T = {
+  private[this] def withSQLConf[T](canRunOnExecutor: Boolean, default: => T)(body: => T): T = {
     if (conf != null) {
+      // When we can reach here, we are on the driver side.
       SQLConf.withExistingConf(conf)(body)
+    } else if (canRunOnExecutor) {
+      body
     } else {
       default
     }
   }
@@ -147,7 +148,8 @@ class AggregatingAccumulator private(
     }
   }
 
-  override def merge(other: AccumulatorV2[InternalRow, InternalRow]): Unit = withSQLConf(()) {
+  override def merge(
+      other: AccumulatorV2[InternalRow, InternalRow]): Unit = withSQLConf(true, ()) {
     if (!other.isZero) {
       other match {
         case agg: AggregatingAccumulator =>
@@ -171,7 +173,7 @@ class AggregatingAccumulator private(
     }
   }
 
-  override def value: InternalRow = withSQLConf(InternalRow.empty) {
+  override def value: InternalRow = withSQLConf(false, InternalRow.empty) {
     // Either use the existing buffer or create a temporary one.
     val input = if (!isZero) {
       buffer
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
index 500425e480..933dabe009 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
@@ -69,7 +69,11 @@ case class CollectMetricsExec(
       // - Performance issues due to excessive serialization.
       val updater = collector.copyAndReset()
       TaskContext.get().addTaskCompletionListener[Unit] { _ =>
-        collector.setState(updater)
+        if (collector.isZero) {
+          collector.setState(updater)
+        } else {
+          collector.merge(updater)
+        }
       }
 
       rows.map { r =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 02f95716ff..7a18d6ea6c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -283,6 +283,40 @@ class DataFrameCallbackSuite extends QueryTest
     }
   }
 
+  test("SPARK-35296: observe should work even if a task contains multiple partitions") {
+    val metricMaps = ArrayBuffer.empty[Map[String, Row]]
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        metricMaps += qe.observedMetrics
+      }
+
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        // No-op
+      }
+    }
+    spark.listenerManager.register(listener)
+    try {
+      val df = spark.range(1, 4, 1, 3)
+        .observe(
+          name = "my_event",
+          count($"id").as("count_val"))
+        .coalesce(2)
+
+      def checkMetrics(metrics: Map[String, Row]): Unit = {
+        assert(metrics.size === 1)
+        assert(metrics("my_event") === Row(3L))
+      }
+
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty()
+      assert(metricMaps.size === 1)
+      checkMetrics(metricMaps.head)
+      metricMaps.clear()
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
+
   testQuietly("SPARK-31144: QueryExecutionListener should receive `java.lang.Error`") {
     var e: Exception = null
     val listener = new QueryExecutionListener {