[SPARK-21238][SQL] allow nested SQL execution

## What changes were proposed in this pull request?

This is another follow-up to https://github.com/apache/spark/pull/18064.

In #18064, we wrapped every SQL command with a SQL execution, which makes nested SQL execution very likely to happen. #18419 tried to improve this a little by introducing `SQLExecution.ignoreNestedExecutionId`. However, this is not friendly to data source developers: they may need to update their code to use the `ignoreNestedExecutionId` API.
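
For illustration, this is roughly the burden the #18419 approach put on a data source. The sink and helper names below are hypothetical; only `SQLExecution.ignoreNestedExecutionId` comes from #18419 (it is the API removed by this PR, see the diff below):

```scala
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.streaming.Sink

// Hypothetical third-party streaming sink, for illustration only.
class MyCustomSink extends Sink {
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    // Before this PR, the nested `collect()` below ran inside the SQL execution
    // that already wraps the streaming batch, so the data source had to opt out
    // of execution-ID tracking explicitly via the #18419 API:
    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
      writeRows(data.collect())
    }
  }

  private def writeRows(rows: Array[Row]): Unit = {
    // write the rows to the external system (omitted)
  }
}
```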

This PR proposes a new solution: simply allow nested execution. The downside is that we may end up with multiple executions for one query. We can improve this by updating the data organization in `SQLListener` to use a 1-n mapping from query to execution instead of a 1-1 mapping; this can be done in a follow-up.
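
As a rough sketch of the new behavior (the object and local names are made up; `SQLExecution.withNewExecutionId` and the fact that `Dataset.collect` calls it internally are taken from the code and comments in the diff below), a nested call now just registers a second execution instead of throwing:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

// Hypothetical driver program, for illustration only.
object NestedExecutionExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("nested-execution-example")
      .getOrCreate()
    val df = spark.range(10).toDF("id")

    // Outer execution, like the one every wrapped SQL command now gets.
    SQLExecution.withNewExecutionId(spark, df.queryExecution) {
      // `collect()` calls withNewExecutionId again internally. Previously this
      // threw IllegalArgumentException (or required ignoreNestedExecutionId);
      // with this PR it simply creates a second execution for the same query.
      df.collect()
    }

    spark.stop()
  }
}
```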

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #18450 from cloud-fan/execution-id.
Wenchen Fan 2017-06-29 14:37:42 +08:00
parent a946be35ac
commit 9f6b3e65cc
8 changed files with 47 additions and 135 deletions

@@ -22,15 +22,12 @@ import java.util.concurrent.atomic.AtomicLong
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd,
-  SparkListenerSQLExecutionStart}
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
 object SQLExecution {
   val EXECUTION_ID_KEY = "spark.sql.execution.id"
-  private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId"
   private val _nextExecutionId = new AtomicLong(0)
   private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
@@ -45,10 +42,8 @@ object SQLExecution {
   private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
     val sc = sparkSession.sparkContext
-    val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null
-    val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null
     // only throw an exception during tests. a missing execution ID should not fail a job.
-    if (testing && !isNestedExecution && !hasExecutionId) {
+    if (testing && sc.getLocalProperty(EXECUTION_ID_KEY) == null) {
       // Attention testers: when a test fails with this exception, it means that the action that
       // started execution of a query didn't call withNewExecutionId. The execution ID should be
       // set by calling withNewExecutionId in the action that begins execution, like
@@ -66,56 +61,27 @@ object SQLExecution {
       queryExecution: QueryExecution)(body: => T): T = {
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
-    if (oldExecutionId == null) {
-      val executionId = SQLExecution.nextExecutionId
-      sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
-      executionIdToQueryExecution.put(executionId, queryExecution)
-      try {
-        // sparkContext.getCallSite() would first try to pick up any call site that was previously
-        // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
-        // streaming queries would give us call site like "run at <unknown>:0"
-        val callSite = sparkSession.sparkContext.getCallSite()
+    val executionId = SQLExecution.nextExecutionId
+    sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
+    executionIdToQueryExecution.put(executionId, queryExecution)
+    try {
+      // sparkContext.getCallSite() would first try to pick up any call site that was previously
+      // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
+      // streaming queries would give us call site like "run at <unknown>:0"
+      val callSite = sparkSession.sparkContext.getCallSite()

-        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
-          executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
-          SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
-        try {
-          body
-        } finally {
-          sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
-            executionId, System.currentTimeMillis()))
-        }
-      } finally {
-        executionIdToQueryExecution.remove(executionId)
-        sc.setLocalProperty(EXECUTION_ID_KEY, null)
-      }
-    } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) {
-      // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the
-      // `body`, so that Spark jobs issued in the `body` won't be tracked.
+      sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
+        executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
+        SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
       try {
-        sc.setLocalProperty(EXECUTION_ID_KEY, null)
         body
       } finally {
-        sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
+        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
+          executionId, System.currentTimeMillis()))
       }
-    } else {
-      // Don't support nested `withNewExecutionId`. This is an example of the nested
-      // `withNewExecutionId`:
-      //
-      // class DataFrame {
-      //   def foo: T = withNewExecutionId { something.createNewDataFrame().collect() }
-      // }
-      //
-      // Note: `collect` will call withNewExecutionId
-      // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan"
-      // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution
-      // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run,
-      // all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
-      //
-      // A real case is the `DataFrame.count` method.
-      throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " +
-        "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " +
-        "jobs issued by the nested execution.")
+    } finally {
+      executionIdToQueryExecution.remove(executionId)
+      sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
     }
   }
@@ -133,20 +99,4 @@ object SQLExecution {
       sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
     }
   }
-  /**
-   * Wrap an action which may have nested execution id. This method can be used to run an execution
-   * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that,
-   * all Spark jobs issued in the body won't be tracked in UI.
-   */
-  def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = {
-    val sc = sparkSession.sparkContext
-    val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID)
-    try {
-      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true")
-      body
-    } finally {
-      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue)
-    }
-  }
 }

@@ -51,9 +51,7 @@ case class AnalyzeTableCommand(
     // 2. when total size is changed, `oldRowCount` becomes invalid.
     // This is to make sure that we only record the right statistics.
     if (!noscan) {
-      val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
-        sparkSession.table(tableIdentWithDB).count()
-      }
+      val newRowCount = sparkSession.table(tableIdentWithDB).count()
       if (newRowCount >= 0 && newRowCount != oldRowCount) {
         newStats = if (newStats.isDefined) {
           newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))

@@ -34,16 +34,14 @@ case class CacheTableCommand(
   override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq
   override def run(sparkSession: SparkSession): Seq[Row] = {
-    SQLExecution.ignoreNestedExecutionId(sparkSession) {
-      plan.foreach { logicalPlan =>
-        Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
-      }
-      sparkSession.catalog.cacheTable(tableIdent.quotedString)
+    plan.foreach { logicalPlan =>
+      Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
+    }
+    sparkSession.catalog.cacheTable(tableIdent.quotedString)
-      if (!isLazy) {
-        // Performs eager caching
-        sparkSession.table(tableIdent).count()
-      }
+    if (!isLazy) {
+      // Performs eager caching
+      sparkSession.table(tableIdent).count()
     }
     Seq.empty[Row]

@@ -145,9 +145,7 @@ object TextInputCSVDataSource extends CSVDataSource {
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) {
-      CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
-    }
+    val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
     inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
   }

@@ -130,11 +130,9 @@ private[sql] case class JDBCRelation(
   }
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
-    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-      data.write
-        .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
-        .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
-    }
+    data.write
+      .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
+      .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
   }
   override def toString: String = {

@@ -48,11 +48,9 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
     println(batchIdStr)
     println("-------------------------------------------")
     // scalastyle:off println
-    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-      data.sparkSession.createDataFrame(
-        data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
-        .show(numRowsToShow, isTruncated)
-    }
+    data.sparkSession.createDataFrame(
+      data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
+      .show(numRowsToShow, isTruncated)
   }
 }
@@ -82,9 +80,7 @@ class ConsoleSinkProvider extends StreamSinkProvider
     // Truncate the displayed data if it is too long, by default it is true
     val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
-    SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) {
-      data.show(numRowsToShow, isTruncated)
-    }
+    data.show(numRowsToShow, isTruncated)
     ConsoleRelation(sqlContext, data)
   }

@@ -194,23 +194,21 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging {
     }
     if (notCommitted) {
       logDebug(s"Committing batch $batchId to $this")
-      SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-        outputMode match {
-          case Append | Update =>
-            val rows = AddedData(batchId, data.collect())
-            synchronized { batches += rows }
+      outputMode match {
+        case Append | Update =>
+          val rows = AddedData(batchId, data.collect())
+          synchronized { batches += rows }
-          case Complete =>
-            val rows = AddedData(batchId, data.collect())
-            synchronized {
-              batches.clear()
-              batches += rows
-            }
+        case Complete =>
+          val rows = AddedData(batchId, data.collect())
+          synchronized {
+            batches.clear()
+            batches += rows
+          }
-          case _ =>
-            throw new IllegalArgumentException(
-              s"Output mode $outputMode is not supported by MemorySink")
-        }
+        case _ =>
+          throw new IllegalArgumentException(
+            s"Output mode $outputMode is not supported by MemorySink")
       }
     } else {
       logDebug(s"Skipping already committed batch: $batchId")

@@ -26,22 +26,9 @@ import org.apache.spark.sql.SparkSession
 class SQLExecutionSuite extends SparkFunSuite {
   test("concurrent query execution (SPARK-10548)") {
-    // Try to reproduce the issue with the old SparkContext
     val conf = new SparkConf()
       .setMaster("local[*]")
       .setAppName("test")
-    val badSparkContext = new BadSparkContext(conf)
-    try {
-      testConcurrentQueryExecution(badSparkContext)
-      fail("unable to reproduce SPARK-10548")
-    } catch {
-      case e: IllegalArgumentException =>
-        assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY))
-    } finally {
-      badSparkContext.stop()
-    }
-    // Verify that the issue is fixed with the latest SparkContext
     val goodSparkContext = new SparkContext(conf)
     try {
       testConcurrentQueryExecution(goodSparkContext)
@@ -134,17 +121,6 @@ class SQLExecutionSuite extends SparkFunSuite {
     }
   }
-  /**
-   * A bad [[SparkContext]] that does not clone the inheritable thread local properties
-   * when passing them to children threads.
-   */
-  private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) {
-    protected[spark] override val localProperties = new InheritableThreadLocal[Properties] {
-      override protected def childValue(parent: Properties): Properties = new Properties(parent)
-      override protected def initialValue(): Properties = new Properties()
-    }
-  }
 object SQLExecutionSuite {
   @volatile var canProgress = false
 }