[SPARK-35695][SQL] Collect observed metrics from cached and adaptive execution sub-trees

### What changes were proposed in this pull request?

`CollectMetricsExec.collect` now also descends into cached (`InMemoryTableScanExec`) and adaptive execution (`AdaptiveSparkPlanExec`, `QueryStageExec`) sub-trees, so observed metrics defined there are collected as well.

### Why are the changes needed?

Currently, persisting/caching a sub-tree hides all of its observed metrics from the `QueryExecutionListener`s, and adaptive query execution can hide them in the same way: `CollectMetricsExec.collect` does not traverse into `InMemoryTableScanExec`, `AdaptiveSparkPlanExec`, or `QueryStageExec` nodes, so any `CollectMetricsExec` nested below them is never visited.
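
For illustration, a minimal standalone sketch of the affected pattern (assuming a local `SparkSession`; the listener wiring mirrors the tests added below):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.functions._
import org.apache.spark.sql.util.QueryExecutionListener

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Record the observed metrics of every successful query.
spark.listenerManager.register(new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit =
    println(qe.observedMetrics) // before this fix: empty when observe() sits under a cache
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
})

// The observe() node ends up inside the cached sub-tree, so its "my_event"
// metrics previously never reached the listener above.
spark.range(100)
  .observe("my_event", min($"id").as("min_val"), max($"id").as("max_val"))
  .persist()
  .collect()
```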

### Does this PR introduce _any_ user-facing change?

Yes, as a bug fix: observed metrics defined inside cached or adaptively executed sub-trees now reach the registered `QueryExecutionListener`s.

### How was this patch tested?

New unit tests in `DataFrameCallbackSuite`.

Closes #32862 from tanelk/SPARK-35695_collect_metrics_persist.

Lead-authored-by: Tanel Kiis <tanel.kiis@gmail.com>
Co-authored-by: tanel.kiis@gmail.com <tanel.kiis@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>

sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala

@@ -22,6 +22,8 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -93,8 +95,14 @@
    */
   def collect(plan: SparkPlan): Map[String, Row] = {
     val metrics = plan.collectWithSubqueries {
-      case collector: CollectMetricsExec => collector.name -> collector.collectedMetrics
+      case collector: CollectMetricsExec => Map(collector.name -> collector.collectedMetrics)
+      case tableScan: InMemoryTableScanExec =>
+        CollectMetricsExec.collect(tableScan.relation.cachedPlan)
+      case adaptivePlan: AdaptiveSparkPlanExec =>
+        CollectMetricsExec.collect(adaptivePlan.executedPlan)
+      case queryStageExec: QueryStageExec =>
+        CollectMetricsExec.collect(queryStageExec.plan)
     }
-    metrics.toMap
+    metrics.reduceOption(_ ++ _).getOrElse(Map.empty)
   }
 }
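
Two details of the new `collect` are worth noting: each matched node now yields a whole `Map[String, Row]`, so a recursive call into a cached or adaptive sub-plan can surface several named metrics at once, and the per-node maps are merged at the end. A standalone sketch of that merge (plain Scala, with `Int` standing in for `Row` and hypothetical per-node results):

```scala
// Hypothetical per-node results, shaped like what collectWithSubqueries returns.
val perNode: Seq[Map[String, Int]] =
  Seq(Map("my_event" -> 1), Map("other_event" -> 2))

// reduceOption(_ ++ _) folds all maps into one; getOrElse covers a plan that
// defines no observed metrics at all (empty Seq => empty Map).
val merged: Map[String, Int] = perNode.reduceOption(_ ++ _).getOrElse(Map.empty)
assert(merged == Map("my_event" -> 1, "other_event" -> 2))
```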

sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala

@@ -17,6 +17,8 @@
 package org.apache.spark.sql.util
 
 import java.lang.{Long => JLong}
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark._
@@ -28,6 +30,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
 import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StringType
@@ -235,6 +238,76 @@ class DataFrameCallbackSuite extends QueryTest
   }
 
   test("get observable metrics by callback") {
+    val df = spark.range(100)
+      .observe(
+        name = "my_event",
+        min($"id").as("min_val"),
+        max($"id").as("max_val"),
+        // Test unresolved alias
+        sum($"id"),
+        count(when($"id" % 2 === 0, 1)).as("num_even"))
+      .observe(
+        name = "other_event",
+        avg($"id").cast("int").as("avg_val"))
+
+    validateObservedMetrics(df)
+  }
+
+  test("SPARK-35296: observe should work even if a task contains multiple partitions") {
+    val df = spark.range(0, 100, 1, 3)
+      .observe(
+        name = "my_event",
+        min($"id").as("min_val"),
+        max($"id").as("max_val"),
+        // Test unresolved alias
+        sum($"id"),
+        count(when($"id" % 2 === 0, 1)).as("num_even"))
+      .observe(
+        name = "other_event",
+        avg($"id").cast("int").as("avg_val"))
+      .coalesce(2)
+
+    validateObservedMetrics(df)
+  }
+
+  test("SPARK-35695: get observable metrics with persist by callback") {
+    val df = spark.range(100)
+      .observe(
+        name = "my_event",
+        min($"id").as("min_val"),
+        max($"id").as("max_val"),
+        // Test unresolved alias
+        sum($"id"),
+        count(when($"id" % 2 === 0, 1)).as("num_even"))
+      .persist()
+      .observe(
+        name = "other_event",
+        avg($"id").cast("int").as("avg_val"))
+      .persist()
+
+    validateObservedMetrics(df)
+  }
+
+  test("SPARK-35695: get observable metrics with adaptive execution by callback") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val df = spark.range(100)
+        .observe(
+          name = "my_event",
+          min($"id").as("min_val"),
+          max($"id").as("max_val"),
+          // Test unresolved alias
+          sum($"id"),
+          count(when($"id" % 2 === 0, 1)).as("num_even"))
+        .repartition($"id")
+        .observe(
+          name = "other_event",
+          avg($"id").cast("int").as("avg_val"))
+
+      validateObservedMetrics(df)
+    }
+  }
+
+  private def validateObservedMetrics(df: Dataset[JLong]): Unit = {
     val metricMaps = ArrayBuffer.empty[Map[String, Row]]
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
@@ -247,18 +320,6 @@
     }
     spark.listenerManager.register(listener)
     try {
-      val df = spark.range(100)
-        .observe(
-          name = "my_event",
-          min($"id").as("min_val"),
-          max($"id").as("max_val"),
-          // Test unresolved alias
-          sum($"id"),
-          count(when($"id" % 2 === 0, 1)).as("num_even"))
-        .observe(
-          name = "other_event",
-          avg($"id").cast("int").as("avg_val"))
-
       def checkMetrics(metrics: Map[String, Row]): Unit = {
         assert(metrics.size === 2)
         assert(metrics("my_event") === Row(0L, 99L, 4950L, 50L))
@@ -283,39 +344,6 @@
     }
   }
 
-  test("SPARK-35296: observe should work even if a task contains multiple partitions") {
-    val metricMaps = ArrayBuffer.empty[Map[String, Row]]
-    val listener = new QueryExecutionListener {
-      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
-        metricMaps += qe.observedMetrics
-      }
-      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-        // No-op
-      }
-    }
-    spark.listenerManager.register(listener)
-    try {
-      val df = spark.range(1, 4, 1, 3)
-        .observe(
-          name = "my_event",
-          count($"id").as("count_val"))
-        .coalesce(2)
-
-      def checkMetrics(metrics: Map[String, Row]): Unit = {
-        assert(metrics.size === 1)
-        assert(metrics("my_event") === Row(3L))
-      }
-
-      df.collect()
-      sparkContext.listenerBus.waitUntilEmpty()
-      assert(metricMaps.size === 1)
-      checkMetrics(metricMaps.head)
-      metricMaps.clear()
-    } finally {
-      spark.listenerManager.unregister(listener)
-    }
-  }
-
   testQuietly("SPARK-31144: QueryExecutionListener should receive `java.lang.Error`") {
     var e: Exception = null