[SPARK-25680][SQL] SQL execution listener shouldn't happen on execution thread

## What changes were proposed in this pull request?

The SQL execution listener framework was created from scratch (see https://github.com/apache/spark/pull/9078). It didn't leverage what we already have in the Spark listener framework, and one major problem is that the listener runs on the Spark execution thread, which means a bad listener can block Spark's query processing.

This PR re-implements the SQL execution listener framework. Now `ExecutionListenerManager` is just a normal Spark listener: it watches for `SparkListenerSQLExecutionEnd` events and posts them to the user-provided SQL execution listeners.
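
For context, here is a minimal sketch of the user-facing API this change affects; the listener class and the session setup below are illustrative, not part of this patch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Hypothetical listener, for illustration only. With this patch its callbacks run
// from the Spark listener bus, so a slow body no longer blocks the query's
// execution thread.
class TimingListener extends QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    println(s"$funcName succeeded in ${durationNs / 1e6} ms")

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
    println(s"$funcName failed: ${exception.getMessage}")
}

val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
spark.listenerManager.register(new TimingListener)
spark.range(10).collect()  // delivered to the listener with funcName = "collect"
```

One consequence, visible in the test changes below, is that listener callbacks are now asynchronous, so tests call `sparkContext.listenerBus.waitUntilEmpty(...)` before asserting on listener state.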

## How was this patch tested?

Existing tests.

Closes #22674 from cloud-fan/listener.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Wenchen Fan 2018-10-17 16:06:07 +08:00
parent e9332f600e
commit 9690eba16e
13 changed files with 170 additions and 106 deletions


@@ -60,6 +60,14 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
}
}
/**
* Remove all listeners and they won't receive any events. This method is thread-safe and can be
* called in any thread.
*/
final def removeAllListeners(): Unit = {
listenersPlusTimers.clear()
}
/**
* This can be overridden by subclasses if there is any extra cleanup to do when removing a
* listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus.


@@ -38,7 +38,9 @@ object MimaExcludes {
lazy val v30excludes = v24excludes ++ Seq(
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues")
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this")
)
// Exclude rules for 2.4.x


@@ -672,17 +672,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
*/
private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
val qe = session.sessionState.executePlan(command)
try {
val start = System.nanoTime()
// call `QueryExecution.toRDD` to trigger the execution of commands.
SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
val end = System.nanoTime()
session.listenerManager.onSuccess(name, qe, end - start)
} catch {
case e: Exception =>
session.listenerManager.onFailure(name, qe, e)
throw e
}
// call `QueryExecution.toRDD` to trigger the execution of commands.
SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd)
}
///////////////////////////////////////////////////////////////////////////////////////


@@ -3356,21 +3356,11 @@ class Dataset[T] private[sql](
* user-registered callback functions.
*/
private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
try {
SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) {
qe.executedPlan.foreach { plan =>
plan.resetMetrics()
}
val start = System.nanoTime()
val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
action(qe.executedPlan)
}
val end = System.nanoTime()
sparkSession.listenerManager.onSuccess(name, qe, end - start)
result
} catch {
case e: Exception =>
sparkSession.listenerManager.onFailure(name, qe, e)
throw e
action(qe.executedPlan)
}
}


@@ -58,7 +58,8 @@ object SQLExecution {
*/
def withNewExecutionId[T](
sparkSession: SparkSession,
queryExecution: QueryExecution)(body: => T): T = {
queryExecution: QueryExecution,
name: Option[String] = None)(body: => T): T = {
val sc = sparkSession.sparkContext
val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
val executionId = SQLExecution.nextExecutionId
@@ -71,14 +72,35 @@
val callSite = sc.getCallSite()
withSQLConfPropagated(sparkSession) {
sc.listenerBus.post(SparkListenerSQLExecutionStart(
executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
var ex: Option[Exception] = None
val startTime = System.nanoTime()
try {
sc.listenerBus.post(SparkListenerSQLExecutionStart(
executionId = executionId,
description = callSite.shortForm,
details = callSite.longForm,
physicalPlanDescription = queryExecution.toString,
// `queryExecution.executedPlan` triggers query planning. If it fails, the exception
// will be caught and reported in the `SparkListenerSQLExecutionEnd`
sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
time = System.currentTimeMillis()))
body
} catch {
case e: Exception =>
ex = Some(e)
throw e
} finally {
sc.listenerBus.post(SparkListenerSQLExecutionEnd(
executionId, System.currentTimeMillis()))
val endTime = System.nanoTime()
val event = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis())
// Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name`
// parameter. The `ExecutionListenerManager` only watches SQL executions with name. We
// can specify the execution name in more places in the future, so that
// `QueryExecutionListener` can track more cases.
event.executionName = name
event.duration = endTime - startTime
event.qe = queryExecution
event.executionFailure = ex
sc.listenerBus.post(event)
}
}
} finally {


@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.ui
import com.fasterxml.jackson.annotation.JsonIgnore
import com.fasterxml.jackson.databind.JavaType
import com.fasterxml.jackson.databind.`type`.TypeFactory
import com.fasterxml.jackson.databind.annotation.JsonDeserialize
@@ -24,8 +25,7 @@ import com.fasterxml.jackson.databind.util.Converter
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.execution.{QueryExecution, SparkPlanInfo}
@DeveloperApi
case class SparkListenerSQLExecutionStart(
@@ -39,7 +39,22 @@ case class SparkListenerSQLExecutionStart(
@DeveloperApi
case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long)
extends SparkListenerEvent
extends SparkListenerEvent {
// The name of the execution, e.g. `df.collect` will trigger a SQL execution with name "collect".
@JsonIgnore private[sql] var executionName: Option[String] = None
// The following 3 fields are only accessed when `executionName` is defined.
// The duration of the SQL execution, in nanoseconds.
@JsonIgnore private[sql] var duration: Long = 0L
// The `QueryExecution` instance that represents the SQL execution
@JsonIgnore private[sql] var qe: QueryExecution = null
// The exception object that caused this execution to fail. None if the execution doesn't fail.
@JsonIgnore private[sql] var executionFailure: Option[Exception] = None
}
/**
* A message used to update SQL metric value for driver-side updates (which doesn't get reflected


@@ -266,8 +266,8 @@ abstract class BaseSessionStateBuilder(
* This gets cloned from parent if available, otherwise a new instance is created.
*/
protected def listenerManager: ExecutionListenerManager = {
parentState.map(_.listenerManager.clone()).getOrElse(
new ExecutionListenerManager(session.sparkContext.conf))
parentState.map(_.listenerManager.clone(session)).getOrElse(
new ExecutionListenerManager(session, loadExtensions = true))
}
/**


@@ -17,17 +17,16 @@
package org.apache.spark.sql.util
import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import scala.util.control.NonFatal
import org.apache.spark.SparkConf
import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.util.Utils
import org.apache.spark.util.{ListenerBus, Utils}
/**
* :: Experimental ::
@@ -75,10 +74,18 @@
*/
@Experimental
@InterfaceStability.Evolving
class ExecutionListenerManager private extends Logging {
// The `session` is used to indicate which session carries this listener manager, and we only
// catch SQL executions which are launched by the same session.
// The `loadExtensions` flag is used to indicate whether we should load the pre-defined,
// user-specified listeners during construction. We should not do it when cloning this listener
// manager, as we will copy all listeners to the cloned listener manager.
class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean)
extends Logging {
private[sql] def this(conf: SparkConf) = {
this()
private val listenerBus = new ExecutionListenerBus(session)
if (loadExtensions) {
val conf = session.sparkContext.conf
conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames =>
Utils.loadExtensions(classOf[QueryExecutionListener], classNames, conf).foreach(register)
}
@@ -88,82 +95,63 @@ class ExecutionListenerManager private extends Logging {
* Registers the specified [[QueryExecutionListener]].
*/
@DeveloperApi
def register(listener: QueryExecutionListener): Unit = writeLock {
listeners += listener
def register(listener: QueryExecutionListener): Unit = {
listenerBus.addListener(listener)
}
/**
* Unregisters the specified [[QueryExecutionListener]].
*/
@DeveloperApi
def unregister(listener: QueryExecutionListener): Unit = writeLock {
listeners -= listener
def unregister(listener: QueryExecutionListener): Unit = {
listenerBus.removeListener(listener)
}
/**
* Removes all the registered [[QueryExecutionListener]].
*/
@DeveloperApi
def clear(): Unit = writeLock {
listeners.clear()
def clear(): Unit = {
listenerBus.removeAllListeners()
}
/**
* Get an identical copy of this listener manager.
*/
@DeveloperApi
override def clone(): ExecutionListenerManager = writeLock {
val newListenerManager = new ExecutionListenerManager
listeners.foreach(newListenerManager.register)
private[sql] def clone(session: SparkSession): ExecutionListenerManager = {
val newListenerManager = new ExecutionListenerManager(session, loadExtensions = false)
listenerBus.listeners.asScala.foreach(newListenerManager.register)
newListenerManager
}
}
private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
readLock {
withErrorHandling { listener =>
listener.onSuccess(funcName, qe, duration)
private[sql] class ExecutionListenerBus(session: SparkSession)
extends SparkListener with ListenerBus[QueryExecutionListener, SparkListenerSQLExecutionEnd] {
session.sparkContext.listenerBus.addToSharedQueue(this)
override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
case e: SparkListenerSQLExecutionEnd => postToAll(e)
case _ =>
}
override protected def doPostEvent(
listener: QueryExecutionListener,
event: SparkListenerSQLExecutionEnd): Unit = {
if (shouldReport(event)) {
val funcName = event.executionName.get
event.executionFailure match {
case Some(ex) =>
listener.onFailure(funcName, event.qe, ex)
case _ =>
listener.onSuccess(funcName, event.qe, event.duration)
}
}
}
private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
readLock {
withErrorHandling { listener =>
listener.onFailure(funcName, qe, exception)
}
}
}
private[this] val listeners = ListBuffer.empty[QueryExecutionListener]
/** A lock to prevent updating the list of listeners while we are traversing through them. */
private[this] val lock = new ReentrantReadWriteLock()
private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = {
for (listener <- listeners) {
try {
f(listener)
} catch {
case NonFatal(e) => logWarning("Error executing query execution listener", e)
}
}
}
/** Acquires a read lock on the cache for the duration of `f`. */
private def readLock[A](f: => A): A = {
val rl = lock.readLock()
rl.lock()
try f finally {
rl.unlock()
}
}
/** Acquires a write lock on the cache for the duration of `f`. */
private def writeLock[A](f: => A): A = {
val wl = lock.writeLock()
wl.lock()
try f finally {
wl.unlock()
}
private def shouldReport(e: SparkListenerSQLExecutionEnd): Boolean = {
// Only catch SQL execution with a name, and triggered by the same spark session that this
// listener manager belongs.
e.executionName.isDefined && e.qe.sparkSession.eq(this.session)
}
}


@@ -155,6 +155,7 @@ class SessionStateSuite extends SparkFunSuite {
assert(forkedSession ne activeSession)
assert(forkedSession.listenerManager ne activeSession.listenerManager)
runCollectQueryOn(forkedSession)
activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
assert(collectorA.commands.length == 1) // forked should callback to A
assert(collectorA.commands(0) == "collect")
@@ -162,12 +163,14 @@
// => changes to forked do not affect original
forkedSession.listenerManager.register(collectorB)
runCollectQueryOn(activeSession)
activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
assert(collectorB.commands.isEmpty) // original should not callback to B
assert(collectorA.commands.length == 2) // original should still callback to A
assert(collectorA.commands(1) == "collect")
// <= changes to original do not affect forked
activeSession.listenerManager.register(collectorC)
runCollectQueryOn(forkedSession)
activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
assert(collectorC.commands.isEmpty) // forked should not callback to C
assert(collectorA.commands.length == 3) // forked should still callback to A
assert(collectorB.commands.length == 1) // forked should still callback to B


@@ -356,10 +356,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
.withColumn("b", udf1($"a", lit(10)))
df.cache()
df.write.saveAsTable("t")
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
df.write.insertInto("t")
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
df.write.save(path.getCanonicalPath)
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(numTotalCachedHit == 3, "expected to be cached in save for native")
}
}


@@ -17,13 +17,15 @@
package org.apache.spark.sql.execution
import org.json4s.jackson.JsonMethods.parse
import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
import org.apache.spark.sql.LocalSparkSession
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
import org.apache.spark.sql.test.TestSparkSession
import org.apache.spark.util.JsonProtocol
class SQLJsonProtocolSuite extends SparkFunSuite {
class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession {
test("SparkPlanGraph backward compatibility: metadata") {
val SQLExecutionStartJsonString =
@@ -49,4 +51,29 @@
new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0)
assert(reconstructedEvent == expectedEvent)
}
test("SparkListenerSQLExecutionEnd backward compatibility") {
spark = new TestSparkSession()
val qe = spark.sql("select 1").queryExecution
val event = SparkListenerSQLExecutionEnd(1, 10)
event.duration = 1000
event.executionName = Some("test")
event.qe = qe
event.executionFailure = Some(new RuntimeException("test"))
val json = JsonProtocol.sparkEventToJson(event)
assert(json == parse(
"""
|{
| "Event" : "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd",
| "executionId" : 1,
| "time" : 10
|}
""".stripMargin))
val readBack = JsonProtocol.sparkEventFromJson(json)
event.duration = 0
event.executionName = None
event.qe = null
event.executionFailure = None
assert(readBack == event)
}
}


@@ -48,6 +48,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
df.select("i").collect()
df.filter($"i" > 0).count()
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(metrics.length == 2)
assert(metrics(0)._1 == "collect")
@@ -78,6 +79,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
val e = intercept[SparkException](df.select(errorUdf($"i")).collect())
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(metrics.length == 1)
assert(metrics(0)._1 == "collect")
assert(metrics(0)._2.analyzed.isInstanceOf[Project])
@@ -103,10 +105,16 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
spark.listenerManager.register(listener)
val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
df.collect()
// Wait for the first `collect` to be caught by our listener. Otherwise the next `collect` will
// reset the plan metrics.
sparkContext.listenerBus.waitUntilEmpty(1000)
df.collect()
Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(metrics.length == 3)
assert(metrics(0) === 1)
assert(metrics(1) === 1)
@@ -154,6 +162,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
// For this simple case, the peakExecutionMemory of a stage should be the data size of the
// aggregate operator, as we only have one memory consuming operator per stage.
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(metrics.length == 2)
assert(metrics(0) == topAggDataSize)
assert(metrics(1) == bottomAggDataSize)
@@ -177,6 +186,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
withTempPath { path =>
spark.range(10).write.format("json").save(path.getCanonicalPath)
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(commands.length == 1)
assert(commands.head._1 == "save")
assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand])
@@ -187,6 +197,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
withTable("tab") {
sql("CREATE TABLE tab(i long) using parquet") // adds commands(1) via onSuccess
spark.range(10).write.insertInto("tab")
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(commands.length == 3)
assert(commands(2)._1 == "insertInto")
assert(commands(2)._2.isInstanceOf[InsertIntoTable])
@@ -197,6 +208,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
withTable("tab") {
spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(commands.length == 5)
assert(commands(4)._1 == "saveAsTable")
assert(commands(4)._2.isInstanceOf[CreateTable])
@@ -208,6 +220,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
val e = intercept[AnalysisException] {
spark.range(10).select($"id", $"id").write.insertInto("tab")
}
sparkContext.listenerBus.waitUntilEmpty(1000)
assert(exceptions.length == 1)
assert(exceptions.head._1 == "insertInto")
assert(exceptions.head._2 == e)


@@ -20,26 +20,28 @@ package org.apache.spark.sql.util
import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark._
import org.apache.spark.sql.{LocalSparkSession, SparkSession}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.internal.StaticSQLConf._
class ExecutionListenerManagerSuite extends SparkFunSuite {
class ExecutionListenerManagerSuite extends SparkFunSuite with LocalSparkSession {
import CountingQueryExecutionListener._
test("register query execution listeners using configuration") {
val conf = new SparkConf(false)
.set(QUERY_EXECUTION_LISTENERS, Seq(classOf[CountingQueryExecutionListener].getName()))
spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
val mgr = new ExecutionListenerManager(conf)
spark.sql("select 1").collect()
spark.sparkContext.listenerBus.waitUntilEmpty(1000)
assert(INSTANCE_COUNT.get() === 1)
mgr.onSuccess(null, null, 42L)
assert(CALLBACK_COUNT.get() === 1)
val clone = mgr.clone()
val cloned = spark.cloneSession()
cloned.sql("select 1").collect()
spark.sparkContext.listenerBus.waitUntilEmpty(1000)
assert(INSTANCE_COUNT.get() === 1)
clone.onSuccess(null, null, 42L)
assert(CALLBACK_COUNT.get() === 2)
}