[SPARK-35296][SQL] Allow Dataset.observe to work even if CollectMetricsExec in a task handles multiple partitions

### What changes were proposed in this pull request?

This PR fixes an issue where `Dataset.observe` doesn't work if `CollectMetricsExec` handles multiple partitions within a single task.
If `coalesce` follows `observe` and shrinks the number of partitions, `CollectMetricsExec` can end up handling multiple partitions in one task.

### Why are the changes needed?

The current implementation of `CollectMetricsExec` doesn't consider the case where it handles multiple partitions in a single task.
Because a new `updater` is created for each partition even though those partitions belong to the same task, the second call to `collector.setState(updater)` raises an assertion error.
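To make the failure mode concrete, here is a minimal, self-contained sketch. The `Collector` class and the callback wiring below are hypothetical stand-ins for `AggregatingAccumulator` and the task completion listener, not the real Spark code:

```scala
// Hypothetical stand-in for AggregatingAccumulator: it refuses to accept
// state more than once, analogous to the assertion in setState.
class Collector {
  private var count = 0L
  private var stateInstalled = false

  def isZero: Boolean = !stateInstalled && count == 0L
  def add(v: Long): Unit = count += v
  def merge(other: Collector): Unit = count += other.count
  def setState(other: Collector): Unit = {
    assert(!stateInstalled) // plays the role of the assertion in AggregatingAccumulator.setState
    stateInstalled = true
    count = other.count
  }
  def copyAndReset(): Collector = new Collector
}

object BrokenTask extends App {
  val collector = new Collector
  // A coalesced task iterates over two upstream partitions; the pre-fix code
  // registers one completion callback per partition, each calling setState.
  val perPartitionCallbacks = (1 to 2).map { _ =>
    val updater = collector.copyAndReset()
    updater.add(1L)
    () => collector.setState(updater)
  }
  // Running the callbacks at "task completion": the second one throws
  // java.lang.AssertionError, like the stack trace below.
  perPartitionCallbacks.foreach(cb => cb())
}
```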
Here is a simple way to reproduce the problem:
```
$ bin/spark-shell --master "local[1]"
scala> spark.range(1, 4, 1, 3).observe("my_event", count($"id").as("count_val")).coalesce(2).collect
```
```
java.lang.AssertionError: assertion failed
	at scala.Predef$.assert(Predef.scala:208)
	at org.apache.spark.sql.execution.AggregatingAccumulator.setState(AggregatingAccumulator.scala:204)
	at org.apache.spark.sql.execution.CollectMetricsExec.$anonfun$doExecute$2(CollectMetricsExec.scala:72)
	at org.apache.spark.sql.execution.CollectMetricsExec.$anonfun$doExecute$2$adapted(CollectMetricsExec.scala:71)
	at org.apache.spark.TaskContext$$anon$1.onTaskCompletion(TaskContext.scala:125)
	at org.apache.spark.TaskContextImpl.$anonfun$markTaskCompleted$1(TaskContextImpl.scala:124)
	at org.apache.spark.TaskContextImpl.$anonfun$markTaskCompleted$1$adapted(TaskContextImpl.scala:124)
	at org.apache.spark.TaskContextImpl.$anonfun$invokeListeners$1(TaskContextImpl.scala:137)
	at org.apache.spark.TaskContextImpl.$anonfun$invokeListeners$1$adapted(TaskContextImpl.scala:135)
```
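With this fix applied, the same query completes successfully and the observed metric for `my_event` reports `count_val = 3` (the count over ids 1, 2 and 3), which is exactly what the new test below asserts.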

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test.

Closes #32786 from sarutak/fix-collectmetricsexec.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>

`AggregatingAccumulator.scala`:

```diff
@@ -33,7 +33,7 @@ class AggregatingAccumulator private(
     bufferSchema: Seq[DataType],
     initialValues: Seq[Expression],
     updateExpressions: Seq[Expression],
-    @transient private val mergeExpressions: Seq[Expression],
+    mergeExpressions: Seq[Expression],
     @transient private val resultExpressions: Seq[Expression],
     imperatives: Array[ImperativeAggregate],
     typedImperatives: Array[TypedImperativeAggregate[_]],
@@ -95,13 +95,14 @@
   /**
    * Driver side operations like `merge` and `value` are executed in the DAGScheduler thread. This
-   * thread does not have a SQL configuration so we attach our own here. Note that we can't (and
-   * shouldn't) call `merge` or `value` on an accumulator originating from an executor so we just
-   * return a default value here.
+   * thread does not have a SQL configuration so we attach our own here.
    */
-  private[this] def withSQLConf[T](default: => T)(body: => T): T = {
+  private[this] def withSQLConf[T](canRunOnExecutor: Boolean, default: => T)(body: => T): T = {
     if (conf != null) {
       // When we can reach here, we are on the driver side.
       SQLConf.withExistingConf(conf)(body)
+    } else if (canRunOnExecutor) {
+      body
     } else {
       default
     }
@@ -147,7 +148,8 @@
     }
   }
 
-  override def merge(other: AccumulatorV2[InternalRow, InternalRow]): Unit = withSQLConf(()) {
+  override def merge(
+      other: AccumulatorV2[InternalRow, InternalRow]): Unit = withSQLConf(true, ()) {
     if (!other.isZero) {
       other match {
         case agg: AggregatingAccumulator =>
@@ -171,7 +173,7 @@
     }
   }
 
-  override def value: InternalRow = withSQLConf(InternalRow.empty) {
+  override def value: InternalRow = withSQLConf(false, InternalRow.empty) {
     // Either use the existing buffer or create a temporary one.
     val input = if (!isZero) {
       buffer
```
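Note on the `AggregatingAccumulator` change: dropping `@transient` from `mergeExpressions` and adding the `canRunOnExecutor` flag go together. With the `CollectMetricsExec` fix below, `merge` can now be invoked on executors (the task completion listener merges per-partition updaters into the collector), so the merge expressions have to be serialized to executors and `withSQLConf` must not short-circuit `merge` there. `value` passes `canRunOnExecutor = false` and therefore still returns the default value anywhere other than the driver.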

`CollectMetricsExec.scala`:

```diff
@@ -69,7 +69,11 @@ case class CollectMetricsExec(
       // - Performance issues due to excessive serialization.
       val updater = collector.copyAndReset()
       TaskContext.get().addTaskCompletionListener[Unit] { _ =>
-        collector.setState(updater)
+        if (collector.isZero) {
+          collector.setState(updater)
+        } else {
+          collector.merge(updater)
+        }
       }
 
       rows.map { r =>
```
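For contrast with the failing sketch in the "Why are the changes needed?" section, here is how the fixed listener logic behaves, again using the hypothetical `Collector` rather than the real classes: the first partition that completes installs its state, and every later partition in the same task merges into the collector instead of calling `setState` again.

```scala
// Sketch only: reuses the hypothetical Collector defined earlier.
object FixedTask extends App {
  val collector = new Collector
  val perPartitionCallbacks = (1 to 2).map { _ =>
    val updater = collector.copyAndReset()
    updater.add(1L)
    () => {
      if (collector.isZero) collector.setState(updater)
      else collector.merge(updater)
    }
  }
  perPartitionCallbacks.foreach(cb => cb())
  assert(!collector.isZero) // both partitions' updates are now accounted for (count == 2)
}
```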

`DataFrameCallbackSuite.scala`:

```diff
@@ -283,6 +283,40 @@ class DataFrameCallbackSuite extends QueryTest
     }
   }
 
+  test("SPARK-35296: observe should work even if a task contains multiple partitions") {
+    val metricMaps = ArrayBuffer.empty[Map[String, Row]]
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        metricMaps += qe.observedMetrics
+      }
+
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        // No-op
+      }
+    }
+    spark.listenerManager.register(listener)
+    try {
+      val df = spark.range(1, 4, 1, 3)
+        .observe(
+          name = "my_event",
+          count($"id").as("count_val"))
+        .coalesce(2)
+
+      def checkMetrics(metrics: Map[String, Row]): Unit = {
+        assert(metrics.size === 1)
+        assert(metrics("my_event") === Row(3L))
+      }
+
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty()
+      assert(metricMaps.size === 1)
+      checkMetrics(metricMaps.head)
+      metricMaps.clear()
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
+
   testQuietly("SPARK-31144: QueryExecutionListener should receive `java.lang.Error`") {
     var e: Exception = null
     val listener = new QueryExecutionListener {
```