[SPARK-35296][SQL] Allow Dataset.observe to work even if CollectMetricsExec in a task handles multiple partitions
### What changes were proposed in this pull request?

This PR fixes an issue where `Dataset.observe` doesn't work if `CollectMetricsExec` in a task handles multiple partitions. If `coalesce` follows `observe` and the number of partitions shrinks after `coalesce`, `CollectMetricsExec` can handle multiple partitions in a single task.

### Why are the changes needed?

The current implementation of `CollectMetricsExec` doesn't consider the case where it handles multiple partitions. Because a new `updater` is created for each partition even though those partitions belong to the same task, `collector.setState(updater)` raises an assertion error.

This is a simple reproducible example:
```
$ bin/spark-shell --master "local[1]"
scala> spark.range(1, 4, 1, 3).observe("my_event", count($"id").as("count_val")).coalesce(2).collect
```
```
java.lang.AssertionError: assertion failed
  at scala.Predef$.assert(Predef.scala:208)
  at org.apache.spark.sql.execution.AggregatingAccumulator.setState(AggregatingAccumulator.scala:204)
  at org.apache.spark.sql.execution.CollectMetricsExec.$anonfun$doExecute$2(CollectMetricsExec.scala:72)
  at org.apache.spark.sql.execution.CollectMetricsExec.$anonfun$doExecute$2$adapted(CollectMetricsExec.scala:71)
  at org.apache.spark.TaskContext$$anon$1.onTaskCompletion(TaskContext.scala:125)
  at org.apache.spark.TaskContextImpl.$anonfun$markTaskCompleted$1(TaskContextImpl.scala:124)
  at org.apache.spark.TaskContextImpl.$anonfun$markTaskCompleted$1$adapted(TaskContextImpl.scala:124)
  at org.apache.spark.TaskContextImpl.$anonfun$invokeListeners$1(TaskContextImpl.scala:137)
  at org.apache.spark.TaskContextImpl.$anonfun$invokeListeners$1$adapted(TaskContextImpl.scala:135)
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test.

Closes #32786 from sarutak/fix-collectmetricsexec.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Parent: e2e3fe7782
Commit: 44b695fbb0
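To make the failure mode concrete before the diff, here is a minimal, self-contained sketch of the bug and the fix described above. `Collector` and `CollectMetricsSketch` are hypothetical stand-ins (not Spark's `AggregatingAccumulator` or `CollectMetricsExec`): a task-level collector receives one `updater` per partition, and only the first partition may overwrite the collector's state; later partitions of the same task must be merged in.

```scala
// Illustrative only: a simplified accumulator that tracks a single Long count.
object CollectMetricsSketch {
  final class Collector {
    private var state: Option[Long] = None

    def isZero: Boolean = state.isEmpty

    // Before the fix this was called once per partition; the assertion fires as
    // soon as a second partition of the same task reaches the completion listener.
    def setState(other: Collector): Unit = {
      assert(isZero, "setState called on a non-empty collector")
      state = other.state
    }

    def merge(other: Collector): Unit =
      state = Some(state.getOrElse(0L) + other.state.getOrElse(0L))

    def add(n: Long): Unit = state = Some(state.getOrElse(0L) + n)

    def value: Long = state.getOrElse(0L)
  }

  def main(args: Array[String]): Unit = {
    val collector = new Collector
    // After coalesce, two of the upstream partitions end up in one task;
    // each partition gets its own updater (copyAndReset() in the real code).
    val partitionCounts = Seq(2L, 1L)
    partitionCounts.foreach { n =>
      val updater = new Collector
      updater.add(n)
      // The fixed handling: setState only while the collector is still empty,
      // merge every subsequent partition's updater.
      if (collector.isZero) collector.setState(updater) else collector.merge(updater)
    }
    println(collector.value) // 3
  }
}
```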
`AggregatingAccumulator.scala`:

```diff
@@ -33,7 +33,7 @@ class AggregatingAccumulator private(
     bufferSchema: Seq[DataType],
     initialValues: Seq[Expression],
     updateExpressions: Seq[Expression],
-    @transient private val mergeExpressions: Seq[Expression],
+    mergeExpressions: Seq[Expression],
     @transient private val resultExpressions: Seq[Expression],
     imperatives: Array[ImperativeAggregate],
     typedImperatives: Array[TypedImperativeAggregate[_]],
@@ -95,13 +95,14 @@ class AggregatingAccumulator private(
   /**
    * Driver side operations like `merge` and `value` are executed in the DAGScheduler thread. This
-   * thread does not have a SQL configuration so we attach our own here. Note that we can't (and
-   * shouldn't) call `merge` or `value` on an accumulator originating from an executor so we just
-   * return a default value here.
+   * thread does not have a SQL configuration so we attach our own here.
    */
-  private[this] def withSQLConf[T](default: => T)(body: => T): T = {
+  private[this] def withSQLConf[T](canRunOnExecutor: Boolean, default: => T)(body: => T): T = {
     if (conf != null) {
       // When we can reach here, we are on the driver side.
       SQLConf.withExistingConf(conf)(body)
+    } else if (canRunOnExecutor) {
+      body
     } else {
       default
     }
@@ -147,7 +148,8 @@ class AggregatingAccumulator private(
     }
   }

-  override def merge(other: AccumulatorV2[InternalRow, InternalRow]): Unit = withSQLConf(()) {
+  override def merge(
+      other: AccumulatorV2[InternalRow, InternalRow]): Unit = withSQLConf(true, ()) {
     if (!other.isZero) {
       other match {
         case agg: AggregatingAccumulator =>
@@ -171,7 +173,7 @@ class AggregatingAccumulator private(
     }
   }

-  override def value: InternalRow = withSQLConf(InternalRow.empty) {
+  override def value: InternalRow = withSQLConf(false, InternalRow.empty) {
     // Either use the existing buffer or create a temporary one.
     val input = if (!isZero) {
       buffer
```
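As I read the `withSQLConf` change above, the point is that the fix makes `merge` callable from the task-completion listener on the executor side, where no driver `SQLConf` is attached; the new `canRunOnExecutor` flag lets `merge` run the body there anyway, while `value` stays driver-only and keeps returning a default elsewhere. A minimal sketch of that guard pattern, with illustrative names (`Conf`, `withConf`) rather than Spark's API:

```scala
// Illustrative stand-in for the withSQLConf(canRunOnExecutor, default)(body) guard.
object WithConfSketch {
  final case class Conf(name: String)

  def withConf[T](conf: Option[Conf], canRunOnExecutor: Boolean, default: => T)(body: => T): T =
    conf match {
      case Some(_)                  => body    // driver side: a conf is attached, run the body
      case None if canRunOnExecutor => body    // executor side: merge-like calls may still run
      case None                     => default // executor side: value-like calls fall back
    }

  def main(args: Array[String]): Unit = {
    println(withConf(Some(Conf("driver")), canRunOnExecutor = true, -1)(42))  // 42
    println(withConf(None, canRunOnExecutor = true, -1)(42))                  // 42 (merge-like)
    println(withConf(None, canRunOnExecutor = false, -1)(42))                 // -1 (value-like)
  }
}
```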
`CollectMetricsExec.scala`:

```diff
@@ -69,7 +69,11 @@ case class CollectMetricsExec(
       // - Performance issues due to excessive serialization.
       val updater = collector.copyAndReset()
       TaskContext.get().addTaskCompletionListener[Unit] { _ =>
-        collector.setState(updater)
+        if (collector.isZero) {
+          collector.setState(updater)
+        } else {
+          collector.merge(updater)
+        }
       }

       rows.map { r =>
```
`DataFrameCallbackSuite.scala`:

```diff
@@ -283,6 +283,40 @@ class DataFrameCallbackSuite extends QueryTest
     }
   }

+  test("SPARK-35296: observe should work even if a task contains multiple partitions") {
+    val metricMaps = ArrayBuffer.empty[Map[String, Row]]
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        metricMaps += qe.observedMetrics
+      }
+
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        // No-op
+      }
+    }
+    spark.listenerManager.register(listener)
+    try {
+      val df = spark.range(1, 4, 1, 3)
+        .observe(
+          name = "my_event",
+          count($"id").as("count_val"))
+        .coalesce(2)
+
+      def checkMetrics(metrics: Map[String, Row]): Unit = {
+        assert(metrics.size === 1)
+        assert(metrics("my_event") === Row(3L))
+      }
+
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty()
+      assert(metricMaps.size === 1)
+      checkMetrics(metricMaps.head)
+      metricMaps.clear()
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
+
   testQuietly("SPARK-31144: QueryExecutionListener should receive `java.lang.Error`") {
     var e: Exception = null
     val listener = new QueryExecutionListener {
```