[SPARK-35695][SQL] Collect observed metrics from cached and adaptive execution sub-trees
### What changes were proposed in this pull request? Collect observed metrics from cached and adaptive execution sub-trees. ### Why are the changes needed? Currently persisting/caching will hide all observed metrics in that sub-tree from reaching the `QueryExecutionListeners`. Adaptive query execution can also hide the metrics from reaching `QueryExecutionListeners`. ### Does this PR introduce _any_ user-facing change? Bugfix ### How was this patch tested? New UTs Closes #32862 from tanelk/SPARK-35695_collect_metrics_persist. Lead-authored-by: Tanel Kiis <tanel.kiis@gmail.com> Co-authored-by: tanel.kiis@gmail.com <tanel.kiis@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
57ce64c511
commit
692dc66c4a
|
@ -22,6 +22,8 @@ import org.apache.spark.sql.Row
|
|||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
|
||||
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
|
||||
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
|
||||
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
/**
|
||||
|
@ -93,8 +95,14 @@ object CollectMetricsExec {
|
|||
*/
|
||||
def collect(plan: SparkPlan): Map[String, Row] = {
|
||||
val metrics = plan.collectWithSubqueries {
|
||||
case collector: CollectMetricsExec => collector.name -> collector.collectedMetrics
|
||||
case collector: CollectMetricsExec => Map(collector.name -> collector.collectedMetrics)
|
||||
case tableScan: InMemoryTableScanExec =>
|
||||
CollectMetricsExec.collect(tableScan.relation.cachedPlan)
|
||||
case adaptivePlan: AdaptiveSparkPlanExec =>
|
||||
CollectMetricsExec.collect(adaptivePlan.executedPlan)
|
||||
case queryStageExec: QueryStageExec =>
|
||||
CollectMetricsExec.collect(queryStageExec.plan)
|
||||
}
|
||||
metrics.toMap
|
||||
metrics.reduceOption(_ ++ _).getOrElse(Map.empty)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.util
|
||||
|
||||
import java.lang.{Long => JLong}
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark._
|
||||
|
@ -28,6 +30,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
|
|||
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
|
||||
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
|
||||
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.test.SharedSparkSession
|
||||
import org.apache.spark.sql.types.StringType
|
||||
|
||||
|
@ -235,18 +238,6 @@ class DataFrameCallbackSuite extends QueryTest
|
|||
}
|
||||
|
||||
test("get observable metrics by callback") {
|
||||
val metricMaps = ArrayBuffer.empty[Map[String, Row]]
|
||||
val listener = new QueryExecutionListener {
|
||||
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
|
||||
metricMaps += qe.observedMetrics
|
||||
}
|
||||
|
||||
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
|
||||
// No-op
|
||||
}
|
||||
}
|
||||
spark.listenerManager.register(listener)
|
||||
try {
|
||||
val df = spark.range(100)
|
||||
.observe(
|
||||
name = "my_event",
|
||||
|
@ -259,6 +250,76 @@ class DataFrameCallbackSuite extends QueryTest
|
|||
name = "other_event",
|
||||
avg($"id").cast("int").as("avg_val"))
|
||||
|
||||
validateObservedMetrics(df)
|
||||
}
|
||||
|
||||
test("SPARK-35296: observe should work even if a task contains multiple partitions") {
|
||||
val df = spark.range(0, 100, 1, 3)
|
||||
.observe(
|
||||
name = "my_event",
|
||||
min($"id").as("min_val"),
|
||||
max($"id").as("max_val"),
|
||||
// Test unresolved alias
|
||||
sum($"id"),
|
||||
count(when($"id" % 2 === 0, 1)).as("num_even"))
|
||||
.observe(
|
||||
name = "other_event",
|
||||
avg($"id").cast("int").as("avg_val"))
|
||||
.coalesce(2)
|
||||
|
||||
validateObservedMetrics(df)
|
||||
}
|
||||
|
||||
test("SPARK-35695: get observable metrics with persist by callback") {
|
||||
val df = spark.range(100)
|
||||
.observe(
|
||||
name = "my_event",
|
||||
min($"id").as("min_val"),
|
||||
max($"id").as("max_val"),
|
||||
// Test unresolved alias
|
||||
sum($"id"),
|
||||
count(when($"id" % 2 === 0, 1)).as("num_even"))
|
||||
.persist()
|
||||
.observe(
|
||||
name = "other_event",
|
||||
avg($"id").cast("int").as("avg_val"))
|
||||
.persist()
|
||||
|
||||
validateObservedMetrics(df)
|
||||
}
|
||||
|
||||
test("SPARK-35695: get observable metrics with adaptive execution by callback") {
|
||||
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
|
||||
val df = spark.range(100)
|
||||
.observe(
|
||||
name = "my_event",
|
||||
min($"id").as("min_val"),
|
||||
max($"id").as("max_val"),
|
||||
// Test unresolved alias
|
||||
sum($"id"),
|
||||
count(when($"id" % 2 === 0, 1)).as("num_even"))
|
||||
.repartition($"id")
|
||||
.observe(
|
||||
name = "other_event",
|
||||
avg($"id").cast("int").as("avg_val"))
|
||||
|
||||
validateObservedMetrics(df)
|
||||
}
|
||||
}
|
||||
|
||||
private def validateObservedMetrics(df: Dataset[JLong]): Unit = {
|
||||
val metricMaps = ArrayBuffer.empty[Map[String, Row]]
|
||||
val listener = new QueryExecutionListener {
|
||||
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
|
||||
metricMaps += qe.observedMetrics
|
||||
}
|
||||
|
||||
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
|
||||
// No-op
|
||||
}
|
||||
}
|
||||
spark.listenerManager.register(listener)
|
||||
try {
|
||||
def checkMetrics(metrics: Map[String, Row]): Unit = {
|
||||
assert(metrics.size === 2)
|
||||
assert(metrics("my_event") === Row(0L, 99L, 4950L, 50L))
|
||||
|
@ -283,39 +344,6 @@ class DataFrameCallbackSuite extends QueryTest
|
|||
}
|
||||
}
|
||||
|
||||
test("SPARK-35296: observe should work even if a task contains multiple partitions") {
|
||||
val metricMaps = ArrayBuffer.empty[Map[String, Row]]
|
||||
val listener = new QueryExecutionListener {
|
||||
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
|
||||
metricMaps += qe.observedMetrics
|
||||
}
|
||||
|
||||
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
|
||||
// No-op
|
||||
}
|
||||
}
|
||||
spark.listenerManager.register(listener)
|
||||
try {
|
||||
val df = spark.range(1, 4, 1, 3)
|
||||
.observe(
|
||||
name = "my_event",
|
||||
count($"id").as("count_val"))
|
||||
.coalesce(2)
|
||||
|
||||
def checkMetrics(metrics: Map[String, Row]): Unit = {
|
||||
assert(metrics.size === 1)
|
||||
assert(metrics("my_event") === Row(3L))
|
||||
}
|
||||
|
||||
df.collect()
|
||||
sparkContext.listenerBus.waitUntilEmpty()
|
||||
assert(metricMaps.size === 1)
|
||||
checkMetrics(metricMaps.head)
|
||||
metricMaps.clear()
|
||||
} finally {
|
||||
spark.listenerManager.unregister(listener)
|
||||
}
|
||||
}
|
||||
|
||||
testQuietly("SPARK-31144: QueryExecutionListener should receive `java.lang.Error`") {
|
||||
var e: Exception = null
|
||||
|
|
Loading…
Reference in a new issue