From 692dc66c4a3660665c1f156df6eeb9ce6f86195e Mon Sep 17 00:00:00 2001
From: Tanel Kiis
Date: Fri, 11 Jun 2021 21:03:08 +0800
Subject: [PATCH] [SPARK-35695][SQL] Collect observed metrics from cached and adaptive execution sub-trees

### What changes were proposed in this pull request?
Collect observed metrics from cached and adaptive execution sub-trees.

### Why are the changes needed?
Currently, persisting/caching a sub-tree hides all observed metrics in that sub-tree from the `QueryExecutionListener`s. Adaptive query execution can likewise prevent the metrics from reaching the listeners.
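
For illustration, here is a minimal sketch of the affected scenario (not part of this patch; it assumes a `spark` session is in scope, e.g. in spark-shell, and only uses the `observe` and `QueryExecutionListener` APIs that the new tests below also exercise). With `persist()` between the two `observe` calls, the metrics observed below the cache boundary previously never reached a registered listener:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.functions._
import org.apache.spark.sql.util.QueryExecutionListener
import spark.implicits._

// Record the observed metrics of every successful query.
val observed = scala.collection.mutable.ArrayBuffer.empty[Map[String, Row]]
spark.listenerManager.register(new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    observed += qe.observedMetrics
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
})

val df = spark.range(100)
  .observe("my_event", min($"id").as("min_val"), max($"id").as("max_val"))
  .persist()   // the "my_event" node now sits inside the cached sub-tree
  .observe("other_event", avg($"id").cast("int").as("avg_val"))

df.collect()
// The listener is notified asynchronously. Before this change the recorded map
// contained only "other_event"; with it, both "my_event" and "other_event" appear.
```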

### Does this PR introduce _any_ user-facing change?
Bugfix

### How was this patch tested?
New UTs

Closes #32862 from tanelk/SPARK-35695_collect_metrics_persist.

Lead-authored-by: Tanel Kiis
Co-authored-by: tanel.kiis@gmail.com
Signed-off-by: Wenchen Fan
---
 .../sql/execution/CollectMetricsExec.scala    |  12 +-
 .../sql/util/DataFrameCallbackSuite.scala     | 118 +++++++++++-------
 2 files changed, 83 insertions(+), 47 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
index 933dabe009..89aeb09676 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -93,8 +95,14 @@ object CollectMetricsExec {
    */
   def collect(plan: SparkPlan): Map[String, Row] = {
     val metrics = plan.collectWithSubqueries {
-      case collector: CollectMetricsExec => collector.name -> collector.collectedMetrics
+      case collector: CollectMetricsExec => Map(collector.name -> collector.collectedMetrics)
+      case tableScan: InMemoryTableScanExec =>
+        CollectMetricsExec.collect(tableScan.relation.cachedPlan)
+      case adaptivePlan: AdaptiveSparkPlanExec =>
+        CollectMetricsExec.collect(adaptivePlan.executedPlan)
+      case queryStageExec: QueryStageExec =>
+        CollectMetricsExec.collect(queryStageExec.plan)
     }
-    metrics.toMap
+    metrics.reduceOption(_ ++ _).getOrElse(Map.empty)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 7a18d6ea6c..01efd9857f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.util
 
+import java.lang.{Long => JLong}
+
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
@@ -28,6 +30,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
 import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StringType
 
@@ -235,6 +238,76 @@ class DataFrameCallbackSuite extends QueryTest
   }
 
   test("get observable metrics by callback") {
+    val df = spark.range(100)
+      .observe(
+        name = "my_event",
+        min($"id").as("min_val"),
+        max($"id").as("max_val"),
+        // Test unresolved alias
+        sum($"id"),
+        count(when($"id" % 2 === 0, 1)).as("num_even"))
+      .observe(
+        name = "other_event",
+        avg($"id").cast("int").as("avg_val"))
+
+    validateObservedMetrics(df)
+  }
+
+  test("SPARK-35296: observe should work even if a task contains multiple partitions") {
+    val df = spark.range(0, 100, 1, 3)
+      .observe(
+        name = "my_event",
+        min($"id").as("min_val"),
+        max($"id").as("max_val"),
+        // Test unresolved alias
+        sum($"id"),
+        count(when($"id" % 2 === 0, 1)).as("num_even"))
+      .observe(
+        name = "other_event",
+        avg($"id").cast("int").as("avg_val"))
+      .coalesce(2)
+
+    validateObservedMetrics(df)
+  }
+
+  test("SPARK-35695: get observable metrics with persist by callback") {
+    val df = spark.range(100)
+      .observe(
+        name = "my_event",
+        min($"id").as("min_val"),
+        max($"id").as("max_val"),
+        // Test unresolved alias
+        sum($"id"),
+        count(when($"id" % 2 === 0, 1)).as("num_even"))
+      .persist()
+      .observe(
+        name = "other_event",
+        avg($"id").cast("int").as("avg_val"))
+      .persist()
+
+    validateObservedMetrics(df)
+  }
+
+  test("SPARK-35695: get observable metrics with adaptive execution by callback") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val df = spark.range(100)
+        .observe(
+          name = "my_event",
+          min($"id").as("min_val"),
+          max($"id").as("max_val"),
+          // Test unresolved alias
+          sum($"id"),
+          count(when($"id" % 2 === 0, 1)).as("num_even"))
+        .repartition($"id")
+        .observe(
+          name = "other_event",
+          avg($"id").cast("int").as("avg_val"))
+
+      validateObservedMetrics(df)
+    }
+  }
+
+  private def validateObservedMetrics(df: Dataset[JLong]): Unit = {
     val metricMaps = ArrayBuffer.empty[Map[String, Row]]
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
@@ -247,18 +320,6 @@
     }
     spark.listenerManager.register(listener)
     try {
-      val df = spark.range(100)
-        .observe(
-          name = "my_event",
-          min($"id").as("min_val"),
-          max($"id").as("max_val"),
-          // Test unresolved alias
-          sum($"id"),
-          count(when($"id" % 2 === 0, 1)).as("num_even"))
-        .observe(
-          name = "other_event",
-          avg($"id").cast("int").as("avg_val"))
-
       def checkMetrics(metrics: Map[String, Row]): Unit = {
         assert(metrics.size === 2)
         assert(metrics("my_event") === Row(0L, 99L, 4950L, 50L))
@@ -283,39 +344,6 @@
     }
   }
 
-  test("SPARK-35296: observe should work even if a task contains multiple partitions") {
-    val metricMaps = ArrayBuffer.empty[Map[String, Row]]
-    val listener = new QueryExecutionListener {
-      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
-        metricMaps += qe.observedMetrics
-      }
-
-      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-        // No-op
-      }
-    }
-    spark.listenerManager.register(listener)
-    try {
-      val df = spark.range(1, 4, 1, 3)
-        .observe(
-          name = "my_event",
-          count($"id").as("count_val"))
-        .coalesce(2)
-
-      def checkMetrics(metrics: Map[String, Row]): Unit = {
-        assert(metrics.size === 1)
-        assert(metrics("my_event") === Row(3L))
-      }
-
-      df.collect()
-      sparkContext.listenerBus.waitUntilEmpty()
-      assert(metricMaps.size === 1)
-      checkMetrics(metricMaps.head)
-      metricMaps.clear()
-    } finally {
-      spark.listenerManager.unregister(listener)
-    }
-  }
 
   testQuietly("SPARK-31144: QueryExecutionListener should receive `java.lang.Error`") {
     var e: Exception = null