[SPARK-35695][SQL] Collect observed metrics from cached and adaptive execution sub-trees
### What changes were proposed in this pull request?

Collect observed metrics from cached and adaptive execution sub-trees.

### Why are the changes needed?

Currently, persisting/caching a sub-tree prevents all observed metrics in that sub-tree from reaching the `QueryExecutionListener`s. Adaptive query execution can also prevent the metrics from reaching the `QueryExecutionListener`s.

### Does this PR introduce _any_ user-facing change?

Bugfix

### How was this patch tested?

New unit tests (UTs)

Closes #32862 from tanelk/SPARK-35695_collect_metrics_persist.

Lead-authored-by: Tanel Kiis <tanel.kiis@gmail.com>
Co-authored-by: tanel.kiis@gmail.com <tanel.kiis@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
57ce64c511
commit
692dc66c4a
|
@ -22,6 +22,8 @@ import org.apache.spark.sql.Row
|
||||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
|
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
|
||||||
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
|
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
|
||||||
|
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
|
||||||
|
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -93,8 +95,14 @@ object CollectMetricsExec {
|
||||||
*/
|
*/
|
||||||
def collect(plan: SparkPlan): Map[String, Row] = {
|
def collect(plan: SparkPlan): Map[String, Row] = {
|
||||||
val metrics = plan.collectWithSubqueries {
|
val metrics = plan.collectWithSubqueries {
|
||||||
case collector: CollectMetricsExec => collector.name -> collector.collectedMetrics
|
case collector: CollectMetricsExec => Map(collector.name -> collector.collectedMetrics)
|
||||||
|
case tableScan: InMemoryTableScanExec =>
|
||||||
|
CollectMetricsExec.collect(tableScan.relation.cachedPlan)
|
||||||
|
case adaptivePlan: AdaptiveSparkPlanExec =>
|
||||||
|
CollectMetricsExec.collect(adaptivePlan.executedPlan)
|
||||||
|
case queryStageExec: QueryStageExec =>
|
||||||
|
CollectMetricsExec.collect(queryStageExec.plan)
|
||||||
}
|
}
|
||||||
metrics.toMap
|
metrics.reduceOption(_ ++ _).getOrElse(Map.empty)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
|
|
||||||
package org.apache.spark.sql.util
|
package org.apache.spark.sql.util
|
||||||
|
|
||||||
|
import java.lang.{Long => JLong}
|
||||||
|
|
||||||
import scala.collection.mutable.ArrayBuffer
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
import org.apache.spark._
|
import org.apache.spark._
|
||||||
|
@ -28,6 +30,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
|
||||||
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
|
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
|
||||||
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
|
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
|
||||||
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
|
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
|
||||||
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
import org.apache.spark.sql.test.SharedSparkSession
|
import org.apache.spark.sql.test.SharedSparkSession
|
||||||
import org.apache.spark.sql.types.StringType
|
import org.apache.spark.sql.types.StringType
|
||||||
|
|
||||||
|
@ -235,6 +238,76 @@ class DataFrameCallbackSuite extends QueryTest
|
||||||
}
|
}
|
||||||
|
|
||||||
test("get observable metrics by callback") {
|
test("get observable metrics by callback") {
|
||||||
|
val df = spark.range(100)
|
||||||
|
.observe(
|
||||||
|
name = "my_event",
|
||||||
|
min($"id").as("min_val"),
|
||||||
|
max($"id").as("max_val"),
|
||||||
|
// Test unresolved alias
|
||||||
|
sum($"id"),
|
||||||
|
count(when($"id" % 2 === 0, 1)).as("num_even"))
|
||||||
|
.observe(
|
||||||
|
name = "other_event",
|
||||||
|
avg($"id").cast("int").as("avg_val"))
|
||||||
|
|
||||||
|
validateObservedMetrics(df)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("SPARK-35296: observe should work even if a task contains multiple partitions") {
|
||||||
|
val df = spark.range(0, 100, 1, 3)
|
||||||
|
.observe(
|
||||||
|
name = "my_event",
|
||||||
|
min($"id").as("min_val"),
|
||||||
|
max($"id").as("max_val"),
|
||||||
|
// Test unresolved alias
|
||||||
|
sum($"id"),
|
||||||
|
count(when($"id" % 2 === 0, 1)).as("num_even"))
|
||||||
|
.observe(
|
||||||
|
name = "other_event",
|
||||||
|
avg($"id").cast("int").as("avg_val"))
|
||||||
|
.coalesce(2)
|
||||||
|
|
||||||
|
validateObservedMetrics(df)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("SPARK-35695: get observable metrics with persist by callback") {
|
||||||
|
val df = spark.range(100)
|
||||||
|
.observe(
|
||||||
|
name = "my_event",
|
||||||
|
min($"id").as("min_val"),
|
||||||
|
max($"id").as("max_val"),
|
||||||
|
// Test unresolved alias
|
||||||
|
sum($"id"),
|
||||||
|
count(when($"id" % 2 === 0, 1)).as("num_even"))
|
||||||
|
.persist()
|
||||||
|
.observe(
|
||||||
|
name = "other_event",
|
||||||
|
avg($"id").cast("int").as("avg_val"))
|
||||||
|
.persist()
|
||||||
|
|
||||||
|
validateObservedMetrics(df)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("SPARK-35695: get observable metrics with adaptive execution by callback") {
|
||||||
|
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
|
||||||
|
val df = spark.range(100)
|
||||||
|
.observe(
|
||||||
|
name = "my_event",
|
||||||
|
min($"id").as("min_val"),
|
||||||
|
max($"id").as("max_val"),
|
||||||
|
// Test unresolved alias
|
||||||
|
sum($"id"),
|
||||||
|
count(when($"id" % 2 === 0, 1)).as("num_even"))
|
||||||
|
.repartition($"id")
|
||||||
|
.observe(
|
||||||
|
name = "other_event",
|
||||||
|
avg($"id").cast("int").as("avg_val"))
|
||||||
|
|
||||||
|
validateObservedMetrics(df)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private def validateObservedMetrics(df: Dataset[JLong]): Unit = {
|
||||||
val metricMaps = ArrayBuffer.empty[Map[String, Row]]
|
val metricMaps = ArrayBuffer.empty[Map[String, Row]]
|
||||||
val listener = new QueryExecutionListener {
|
val listener = new QueryExecutionListener {
|
||||||
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
|
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
|
||||||
|
@ -247,18 +320,6 @@ class DataFrameCallbackSuite extends QueryTest
|
||||||
}
|
}
|
||||||
spark.listenerManager.register(listener)
|
spark.listenerManager.register(listener)
|
||||||
try {
|
try {
|
||||||
val df = spark.range(100)
|
|
||||||
.observe(
|
|
||||||
name = "my_event",
|
|
||||||
min($"id").as("min_val"),
|
|
||||||
max($"id").as("max_val"),
|
|
||||||
// Test unresolved alias
|
|
||||||
sum($"id"),
|
|
||||||
count(when($"id" % 2 === 0, 1)).as("num_even"))
|
|
||||||
.observe(
|
|
||||||
name = "other_event",
|
|
||||||
avg($"id").cast("int").as("avg_val"))
|
|
||||||
|
|
||||||
def checkMetrics(metrics: Map[String, Row]): Unit = {
|
def checkMetrics(metrics: Map[String, Row]): Unit = {
|
||||||
assert(metrics.size === 2)
|
assert(metrics.size === 2)
|
||||||
assert(metrics("my_event") === Row(0L, 99L, 4950L, 50L))
|
assert(metrics("my_event") === Row(0L, 99L, 4950L, 50L))
|
||||||
|
@ -283,39 +344,6 @@ class DataFrameCallbackSuite extends QueryTest
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
test("SPARK-35296: observe should work even if a task contains multiple partitions") {
|
|
||||||
val metricMaps = ArrayBuffer.empty[Map[String, Row]]
|
|
||||||
val listener = new QueryExecutionListener {
|
|
||||||
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
|
|
||||||
metricMaps += qe.observedMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
|
|
||||||
// No-op
|
|
||||||
}
|
|
||||||
}
|
|
||||||
spark.listenerManager.register(listener)
|
|
||||||
try {
|
|
||||||
val df = spark.range(1, 4, 1, 3)
|
|
||||||
.observe(
|
|
||||||
name = "my_event",
|
|
||||||
count($"id").as("count_val"))
|
|
||||||
.coalesce(2)
|
|
||||||
|
|
||||||
def checkMetrics(metrics: Map[String, Row]): Unit = {
|
|
||||||
assert(metrics.size === 1)
|
|
||||||
assert(metrics("my_event") === Row(3L))
|
|
||||||
}
|
|
||||||
|
|
||||||
df.collect()
|
|
||||||
sparkContext.listenerBus.waitUntilEmpty()
|
|
||||||
assert(metricMaps.size === 1)
|
|
||||||
checkMetrics(metricMaps.head)
|
|
||||||
metricMaps.clear()
|
|
||||||
} finally {
|
|
||||||
spark.listenerManager.unregister(listener)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
testQuietly("SPARK-31144: QueryExecutionListener should receive `java.lang.Error`") {
|
testQuietly("SPARK-31144: QueryExecutionListener should receive `java.lang.Error`") {
|
||||||
var e: Exception = null
|
var e: Exception = null
|
||||||
|
|
Loading…
Reference in a new issue