[SPARK-25680][SQL] SQL execution listener shouldn't happen on execution thread
## What changes were proposed in this pull request?

The SQL execution listener framework was created from scratch (see https://github.com/apache/spark/pull/9078). It didn't leverage what we already have in the Spark listener framework, and one major problem is that the listeners run on the Spark execution thread, which means a badly behaved listener can block Spark's query processing.

This PR re-implements the SQL execution listener framework. Now `ExecutionListenerManager` is just a normal Spark listener, which watches for `SparkListenerSQLExecutionEnd` events and posts them to the user-provided SQL execution listeners.

## How was this patch tested?

Existing tests.

Closes #22674 from cloud-fan/listener.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
parent e9332f600e · commit 9690eba16e
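For orientation before the diffs: the user-facing API is unchanged by this commit. A `QueryExecutionListener` is still registered through `spark.listenerManager`; only its callbacks now run on a listener-bus thread instead of the query's execution thread. A minimal sketch (the listener class and object names are illustrative, not part of this diff):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Illustrative listener: logs every named execution.
class LoggingListener extends QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    // A slow body here no longer stalls the thread that ran the query.
    println(s"$funcName succeeded in ${durationNs / 1e6} ms")
  }
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
    println(s"$funcName failed: ${exception.getMessage}")
  }
}

object ListenerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("listener-demo").getOrCreate()
    spark.listenerManager.register(new LoggingListener)
    spark.range(10).collect() // triggers an execution named "collect"
    spark.stop()
  }
}
```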
In org.apache.spark.util.ListenerBus:

```diff
@@ -60,6 +60,14 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
     }
   }
 
+  /**
+   * Remove all listeners and they won't receive any events. This method is thread-safe and can be
+   * called in any thread.
+   */
+  final def removeAllListeners(): Unit = {
+    listenersPlusTimers.clear()
+  }
+
   /**
    * This can be overridden by subclasses if there is any extra cleanup to do when removing a
    * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus.
```
In MimaExcludes:

```diff
@@ -38,7 +38,9 @@ object MimaExcludes {
   lazy val v30excludes = v24excludes ++ Seq(
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"),
     ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"),
-    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues")
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this")
   )
 
   // Exclude rules for 2.4.x
```
In DataFrameWriter:

```diff
@@ -672,17 +672,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    */
   private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
     val qe = session.sessionState.executePlan(command)
-    try {
-      val start = System.nanoTime()
-      // call `QueryExecution.toRDD` to trigger the execution of commands.
-      SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
-      val end = System.nanoTime()
-      session.listenerManager.onSuccess(name, qe, end - start)
-    } catch {
-      case e: Exception =>
-        session.listenerManager.onFailure(name, qe, e)
-        throw e
-    }
+    // call `QueryExecution.toRDD` to trigger the execution of commands.
+    SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd)
   }
 
   ///////////////////////////////////////////////////////////////////////////////////////
```
In Dataset:

```diff
@@ -3356,21 +3356,11 @@ class Dataset[T] private[sql](
    * user-registered callback functions.
    */
  private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
-    try {
+    SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) {
       qe.executedPlan.foreach { plan =>
         plan.resetMetrics()
       }
-      val start = System.nanoTime()
-      val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
-        action(qe.executedPlan)
-      }
-      val end = System.nanoTime()
-      sparkSession.listenerManager.onSuccess(name, qe, end - start)
-      result
-    } catch {
-      case e: Exception =>
-        sparkSession.listenerManager.onFailure(name, qe, e)
-        throw e
+      action(qe.executedPlan)
     }
   }
```
@ -58,7 +58,8 @@ object SQLExecution {
|
|||
*/
|
||||
def withNewExecutionId[T](
|
||||
sparkSession: SparkSession,
|
||||
queryExecution: QueryExecution)(body: => T): T = {
|
||||
queryExecution: QueryExecution,
|
||||
name: Option[String] = None)(body: => T): T = {
|
||||
val sc = sparkSession.sparkContext
|
||||
val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
|
||||
val executionId = SQLExecution.nextExecutionId
|
||||
|
@ -71,14 +72,35 @@ object SQLExecution {
|
|||
val callSite = sc.getCallSite()
|
||||
|
||||
withSQLConfPropagated(sparkSession) {
|
||||
sc.listenerBus.post(SparkListenerSQLExecutionStart(
|
||||
executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
|
||||
SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
|
||||
var ex: Option[Exception] = None
|
||||
val startTime = System.nanoTime()
|
||||
try {
|
||||
sc.listenerBus.post(SparkListenerSQLExecutionStart(
|
||||
executionId = executionId,
|
||||
description = callSite.shortForm,
|
||||
details = callSite.longForm,
|
||||
physicalPlanDescription = queryExecution.toString,
|
||||
// `queryExecution.executedPlan` triggers query planning. If it fails, the exception
|
||||
// will be caught and reported in the `SparkListenerSQLExecutionEnd`
|
||||
sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
|
||||
time = System.currentTimeMillis()))
|
||||
body
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
ex = Some(e)
|
||||
throw e
|
||||
} finally {
|
||||
sc.listenerBus.post(SparkListenerSQLExecutionEnd(
|
||||
executionId, System.currentTimeMillis()))
|
||||
val endTime = System.nanoTime()
|
||||
val event = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis())
|
||||
// Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name`
|
||||
// parameter. The `ExecutionListenerManager` only watches SQL executions with name. We
|
||||
// can specify the execution name in more places in the future, so that
|
||||
// `QueryExecutionListener` can track more cases.
|
||||
event.executionName = name
|
||||
event.duration = endTime - startTime
|
||||
event.qe = queryExecution
|
||||
event.executionFailure = ex
|
||||
sc.listenerBus.post(event)
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
|
|
|
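The shape of the change above is worth spelling out: the start event and the user body now share one try block, exceptions are recorded rather than swallowed, and the end event is always posted from `finally`, enriched with the name, duration, query execution, and failure. A generic sketch of that pattern (not the Spark source; `postEndEvent` is a hypothetical stand-in for `sc.listenerBus.post(event)`):

```scala
// Generic sketch of the "enriched end event" pattern used in the hunk above.
object EndEventPattern {
  // Hypothetical helper standing in for posting to a listener bus.
  private def postEndEvent(
      name: Option[String], durationNs: Long, failure: Option[Exception]): Unit = {
    println(s"end event: name=$name durationNs=$durationNs failed=${failure.isDefined}")
  }

  def runWithEvents[T](name: Option[String])(body: => T): T = {
    var failure: Option[Exception] = None
    val start = System.nanoTime()
    try {
      body
    } catch {
      case e: Exception =>
        failure = Some(e)
        throw e // the caller still sees the original exception
    } finally {
      // Runs on success *and* failure; listeners later dispatch to
      // onSuccess/onFailure based on whether `failure` is defined.
      postEndEvent(name, System.nanoTime() - start, failure)
    }
  }
}
```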
In org.apache.spark.sql.execution.ui (SQL execution events):

```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.ui
 
+import com.fasterxml.jackson.annotation.JsonIgnore
 import com.fasterxml.jackson.databind.JavaType
 import com.fasterxml.jackson.databind.`type`.TypeFactory
 import com.fasterxml.jackson.databind.annotation.JsonDeserialize
@@ -24,8 +25,7 @@ import com.fasterxml.jackson.databind.util.Converter
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.scheduler._
-import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.metric._
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlanInfo}
 
 @DeveloperApi
 case class SparkListenerSQLExecutionStart(
@@ -39,7 +39,22 @@ case class SparkListenerSQLExecutionStart(
 
 @DeveloperApi
 case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long)
-  extends SparkListenerEvent
+  extends SparkListenerEvent {
+
+  // The name of the execution, e.g. `df.collect` will trigger a SQL execution with name "collect".
+  @JsonIgnore private[sql] var executionName: Option[String] = None
+
+  // The following 3 fields are only accessed when `executionName` is defined.
+
+  // The duration of the SQL execution, in nanoseconds.
+  @JsonIgnore private[sql] var duration: Long = 0L
+
+  // The `QueryExecution` instance that represents the SQL execution
+  @JsonIgnore private[sql] var qe: QueryExecution = null
+
+  // The exception object that caused this execution to fail. None if the execution doesn't fail.
+  @JsonIgnore private[sql] var executionFailure: Option[Exception] = None
+}
 
 /**
  * A message used to update SQL metric value for driver-side updates (which doesn't get reflected
```
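Keeping `SparkListenerSQLExecutionEnd` a two-field case class while carrying the extra data in `@JsonIgnore` vars preserves both the public constructor and the JSON wire format consumed by the history server; the extras exist only for in-process delivery. A toy illustration of the idea (hypothetical event class, assuming jackson-annotations on the classpath):

```scala
import com.fasterxml.jackson.annotation.JsonIgnore

// Toy event: only the constructor fields survive JSON serialization.
// The vars come back as these defaults on read-back, which is exactly
// what the SQLJsonProtocolSuite test later in this commit asserts.
case class EndEvent(executionId: Long, time: Long) {
  @JsonIgnore var executionName: Option[String] = None
  @JsonIgnore var durationNs: Long = 0L
}
```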
In BaseSessionStateBuilder:

```diff
@@ -266,8 +266,8 @@ abstract class BaseSessionStateBuilder(
    * This gets cloned from parent if available, otherwise a new instance is created.
    */
   protected def listenerManager: ExecutionListenerManager = {
-    parentState.map(_.listenerManager.clone()).getOrElse(
-      new ExecutionListenerManager(session.sparkContext.conf))
+    parentState.map(_.listenerManager.clone(session)).getOrElse(
+      new ExecutionListenerManager(session, loadExtensions = true))
   }
 
   /**
```
In org.apache.spark.sql.util (imports for `QueryExecutionListener` and its manager):

```diff
@@ -17,17 +17,16 @@
 
 package org.apache.spark.sql.util
 
-import java.util.concurrent.locks.ReentrantReadWriteLock
-
-import scala.collection.mutable.ListBuffer
-import scala.util.control.NonFatal
+import scala.collection.JavaConverters._
 
-import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
 import org.apache.spark.sql.internal.StaticSQLConf._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ListenerBus, Utils}
 
 /**
  * :: Experimental ::
```
Same file, the `ExecutionListenerManager` declaration and constructor:

```diff
@@ -75,10 +74,18 @@ trait QueryExecutionListener {
  */
 @Experimental
 @InterfaceStability.Evolving
-class ExecutionListenerManager private extends Logging {
+// The `session` is used to indicate which session carries this listener manager, and we only
+// catch SQL executions which are launched by the same session.
+// The `loadExtensions` flag is used to indicate whether we should load the pre-defined,
+// user-specified listeners during construction. We should not do it when cloning this listener
+// manager, as we will copy all listeners to the cloned listener manager.
+class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean)
+  extends Logging {
 
-  private[sql] def this(conf: SparkConf) = {
-    this()
+  private val listenerBus = new ExecutionListenerBus(session)
+
+  if (loadExtensions) {
+    val conf = session.sparkContext.conf
     conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames =>
       Utils.loadExtensions(classOf[QueryExecutionListener], classNames, conf).foreach(register)
     }
```
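`QUERY_EXECUTION_LISTENERS` is the static SQL config `spark.sql.queryExecutionListeners`; with `loadExtensions = true` the classes it names are instantiated once when the session's manager is built, and cloning skips the reload because the clone re-registers the live listener instances instead. Hypothetical usage (the listener class name is made up):

```scala
import org.apache.spark.sql.SparkSession

object ConfDemo {
  def main(args: Array[String]): Unit = {
    // com.example.AuditListener is a made-up QueryExecutionListener with a
    // zero-arg constructor. The key is static, so it must be set before the
    // first SparkSession in the JVM is created.
    val spark = SparkSession.builder()
      .master("local")
      .config("spark.sql.queryExecutionListeners", "com.example.AuditListener")
      .getOrCreate()
    spark.stop()
  }
}
```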
Same file, the manager body and the new `ExecutionListenerBus`:

```diff
@@ -88,82 +95,63 @@ class ExecutionListenerManager private extends Logging {
    * Registers the specified [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def register(listener: QueryExecutionListener): Unit = writeLock {
-    listeners += listener
+  def register(listener: QueryExecutionListener): Unit = {
+    listenerBus.addListener(listener)
   }
 
   /**
    * Unregisters the specified [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def unregister(listener: QueryExecutionListener): Unit = writeLock {
-    listeners -= listener
+  def unregister(listener: QueryExecutionListener): Unit = {
+    listenerBus.removeListener(listener)
   }
 
   /**
    * Removes all the registered [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def clear(): Unit = writeLock {
-    listeners.clear()
+  def clear(): Unit = {
+    listenerBus.removeAllListeners()
   }
 
   /**
    * Get an identical copy of this listener manager.
    */
-  @DeveloperApi
-  override def clone(): ExecutionListenerManager = writeLock {
-    val newListenerManager = new ExecutionListenerManager
-    listeners.foreach(newListenerManager.register)
+  private[sql] def clone(session: SparkSession): ExecutionListenerManager = {
+    val newListenerManager = new ExecutionListenerManager(session, loadExtensions = false)
+    listenerBus.listeners.asScala.foreach(newListenerManager.register)
     newListenerManager
   }
+}
 
-  private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
-    readLock {
-      withErrorHandling { listener =>
-        listener.onSuccess(funcName, qe, duration)
-      }
-    }
-  }
+private[sql] class ExecutionListenerBus(session: SparkSession)
+  extends SparkListener with ListenerBus[QueryExecutionListener, SparkListenerSQLExecutionEnd] {
 
-  private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-    readLock {
-      withErrorHandling { listener =>
-        listener.onFailure(funcName, qe, exception)
-      }
-    }
-  }
+  session.sparkContext.listenerBus.addToSharedQueue(this)
 
-  private[this] val listeners = ListBuffer.empty[QueryExecutionListener]
+  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+    case e: SparkListenerSQLExecutionEnd => postToAll(e)
+    case _ =>
+  }
 
-  /** A lock to prevent updating the list of listeners while we are traversing through them. */
-  private[this] val lock = new ReentrantReadWriteLock()
-
-  private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = {
-    for (listener <- listeners) {
-      try {
-        f(listener)
-      } catch {
-        case NonFatal(e) => logWarning("Error executing query execution listener", e)
-      }
-    }
-  }
+  override protected def doPostEvent(
+      listener: QueryExecutionListener,
+      event: SparkListenerSQLExecutionEnd): Unit = {
+    if (shouldReport(event)) {
+      val funcName = event.executionName.get
+      event.executionFailure match {
+        case Some(ex) =>
+          listener.onFailure(funcName, event.qe, ex)
+        case _ =>
+          listener.onSuccess(funcName, event.qe, event.duration)
+      }
+    }
+  }
 
-  /** Acquires a read lock on the cache for the duration of `f`. */
-  private def readLock[A](f: => A): A = {
-    val rl = lock.readLock()
-    rl.lock()
-    try f finally {
-      rl.unlock()
-    }
-  }
-
-  /** Acquires a write lock on the cache for the duration of `f`. */
-  private def writeLock[A](f: => A): A = {
-    val wl = lock.writeLock()
-    wl.lock()
-    try f finally {
-      wl.unlock()
-    }
-  }
+  private def shouldReport(e: SparkListenerSQLExecutionEnd): Boolean = {
+    // Only catch SQL execution with a name, and triggered by the same spark session that this
+    // listener manager belongs.
+    e.executionName.isDefined && e.qe.sparkSession.eq(this.session)
+  }
 }
```
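Two things fall out of this rewrite: the per-manager locking and `withErrorHandling` machinery disappears because `ListenerBus.postToAll` already isolates listener failures, and `shouldReport` scopes delivery to the owning session, so a forked session's listeners never see the parent's queries. A sketch of the session scoping (names illustrative; the sleep is a crude stand-in for draining the bus):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object ScopeDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("scope-demo").getOrCreate()
    spark.listenerManager.register(new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        println(s"listener saw: $funcName")
      override def onFailure(funcName: String, qe: QueryExecution, e: Exception): Unit = ()
    })
    // clone(session) above re-registers the same listener instances on the fork
    val forked = spark.cloneSession()

    spark.sql("select 1").collect()  // reported through spark's manager only
    forked.sql("select 1").collect() // reported through forked's manager only
    // A listener registered on either session from now on stays local to it.
    Thread.sleep(2000) // give the async bus time to deliver before exiting
    spark.stop()
  }
}
```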
In SessionStateSuite:

```diff
@@ -155,6 +155,7 @@ class SessionStateSuite extends SparkFunSuite {
     assert(forkedSession ne activeSession)
     assert(forkedSession.listenerManager ne activeSession.listenerManager)
     runCollectQueryOn(forkedSession)
+    activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(collectorA.commands.length == 1) // forked should callback to A
     assert(collectorA.commands(0) == "collect")
@@ -162,12 +163,14 @@
     // => changes to forked do not affect original
     forkedSession.listenerManager.register(collectorB)
     runCollectQueryOn(activeSession)
+    activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(collectorB.commands.isEmpty) // original should not callback to B
     assert(collectorA.commands.length == 2) // original should still callback to A
     assert(collectorA.commands(1) == "collect")
     // <= changes to original do not affect forked
     activeSession.listenerManager.register(collectorC)
     runCollectQueryOn(forkedSession)
+    activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(collectorC.commands.isEmpty) // forked should not callback to C
     assert(collectorA.commands.length == 3) // forked should still callback to A
     assert(collectorB.commands.length == 1) // forked should still callback to B
```
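The suite changes above set the pattern repeated through the rest of this commit: because callbacks are now delivered asynchronously, every assertion on listener state must first drain the shared queue via `waitUntilEmpty`. A self-contained sketch of why (names illustrative; `listenerBus` is `private[spark]`, so only Spark's own test sources can call `waitUntilEmpty` directly):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object DrainDemo {
  @volatile private var seen = 0

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("drain-demo").getOrCreate()
    spark.listenerManager.register(new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        seen += 1
      override def onFailure(funcName: String, qe: QueryExecution, e: Exception): Unit = ()
    })
    spark.range(1).collect()
    // `seen` may still be 0 right here: the callback runs on the bus thread.
    Thread.sleep(2000) // crude stand-in for listenerBus.waitUntilEmpty(1000)
    assert(seen == 1)
    spark.stop()
  }
}
```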
In UDFSuite:

```diff
@@ -356,10 +356,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       .withColumn("b", udf1($"a", lit(10)))
     df.cache()
     df.write.saveAsTable("t")
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
     df.write.insertInto("t")
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
     df.write.save(path.getCanonicalPath)
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(numTotalCachedHit == 3, "expected to be cached in save for native")
   }
 }
```
In SQLJsonProtocolSuite:

```diff
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.execution
 
-import org.json4s.jackson.JsonMethods.parse
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
+import org.apache.spark.sql.LocalSparkSession
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
+import org.apache.spark.sql.test.TestSparkSession
 import org.apache.spark.util.JsonProtocol
 
-class SQLJsonProtocolSuite extends SparkFunSuite {
+class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession {
 
   test("SparkPlanGraph backward compatibility: metadata") {
     val SQLExecutionStartJsonString =
@@ -49,4 +51,29 @@
       new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0)
     assert(reconstructedEvent == expectedEvent)
   }
+
+  test("SparkListenerSQLExecutionEnd backward compatibility") {
+    spark = new TestSparkSession()
+    val qe = spark.sql("select 1").queryExecution
+    val event = SparkListenerSQLExecutionEnd(1, 10)
+    event.duration = 1000
+    event.executionName = Some("test")
+    event.qe = qe
+    event.executionFailure = Some(new RuntimeException("test"))
+    val json = JsonProtocol.sparkEventToJson(event)
+    assert(json == parse(
+      """
+        |{
+        |  "Event" : "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd",
+        |  "executionId" : 1,
+        |  "time" : 10
+        |}
+      """.stripMargin))
+    val readBack = JsonProtocol.sparkEventFromJson(json)
+    event.duration = 0
+    event.executionName = None
+    event.qe = null
+    event.executionFailure = None
+    assert(readBack == event)
+  }
 }
```
In DataFrameCallbackSuite:

```diff
@@ -48,6 +48,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     df.select("i").collect()
     df.filter($"i" > 0).count()
 
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 2)
 
     assert(metrics(0)._1 == "collect")
@@ -78,6 +79,7 @@
 
     val e = intercept[SparkException](df.select(errorUdf($"i")).collect())
 
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 1)
     assert(metrics(0)._1 == "collect")
     assert(metrics(0)._2.analyzed.isInstanceOf[Project])
@@ -103,10 +105,16 @@
     spark.listenerManager.register(listener)
 
     val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
 
     df.collect()
+    // Wait for the first `collect` to be caught by our listener. Otherwise the next `collect` will
+    // reset the plan metrics.
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     df.collect()
+
     Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
 
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 3)
     assert(metrics(0) === 1)
     assert(metrics(1) === 1)
@@ -154,6 +162,7 @@
 
     // For this simple case, the peakExecutionMemory of a stage should be the data size of the
     // aggregate operator, as we only have one memory consuming operator per stage.
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 2)
     assert(metrics(0) == topAggDataSize)
     assert(metrics(1) == bottomAggDataSize)
@@ -177,6 +186,7 @@
 
     withTempPath { path =>
       spark.range(10).write.format("json").save(path.getCanonicalPath)
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 1)
       assert(commands.head._1 == "save")
       assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand])
@@ -187,6 +197,7 @@
     withTable("tab") {
       sql("CREATE TABLE tab(i long) using parquet") // adds commands(1) via onSuccess
       spark.range(10).write.insertInto("tab")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 3)
       assert(commands(2)._1 == "insertInto")
       assert(commands(2)._2.isInstanceOf[InsertIntoTable])
@@ -197,6 +208,7 @@
 
     withTable("tab") {
       spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 5)
       assert(commands(4)._1 == "saveAsTable")
       assert(commands(4)._2.isInstanceOf[CreateTable])
@@ -208,6 +220,7 @@
     val e = intercept[AnalysisException] {
       spark.range(10).select($"id", $"id").write.insertInto("tab")
     }
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(exceptions.length == 1)
     assert(exceptions.head._1 == "insertInto")
     assert(exceptions.head._2 == e)
```
In ExecutionListenerManagerSuite:

```diff
@@ -20,26 +20,28 @@ package org.apache.spark.sql.util
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark._
+import org.apache.spark.sql.{LocalSparkSession, SparkSession}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.internal.StaticSQLConf._
 
-class ExecutionListenerManagerSuite extends SparkFunSuite {
+class ExecutionListenerManagerSuite extends SparkFunSuite with LocalSparkSession {
 
   import CountingQueryExecutionListener._
 
   test("register query execution listeners using configuration") {
     val conf = new SparkConf(false)
       .set(QUERY_EXECUTION_LISTENERS, Seq(classOf[CountingQueryExecutionListener].getName()))
+    spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
 
-    val mgr = new ExecutionListenerManager(conf)
+    spark.sql("select 1").collect()
+    spark.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(INSTANCE_COUNT.get() === 1)
-    mgr.onSuccess(null, null, 42L)
     assert(CALLBACK_COUNT.get() === 1)
 
-    val clone = mgr.clone()
+    val cloned = spark.cloneSession()
+    cloned.sql("select 1").collect()
+    spark.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(INSTANCE_COUNT.get() === 1)
 
-    clone.onSuccess(null, null, 42L)
     assert(CALLBACK_COUNT.get() === 2)
   }
```