[SPARK-20213][SQL][FOLLOW-UP] introduce SQLExecution.ignoreNestedExecutionId

## What changes were proposed in this pull request?

In https://github.com/apache/spark/pull/18064, to work around the nested SQL execution id issue, we introduced several internal methods in `Dataset`, such as `collectInternal`, `countInternal`, and `showInternal`, to avoid nested execution ids.

However, this approach does not extend well: whenever we hit another nested execution id case, we have to add yet another internal method to `Dataset`.

Our goal is to ignore the nested execution id in certain cases, and a better way to achieve this is to introduce `SQLExecution.ignoreNestedExecutionId`. Whenever we find a place that needs to ignore the nested execution, we can simply wrap the action with `SQLExecution.ignoreNestedExecutionId`, which is more extensible than the previous approach.
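To illustrate the pattern, here is a minimal sketch (not part of this patch; the `NestedExecutionExample` object, the `eagerCount` helper, and the table name are made up for illustration) of wrapping a nested action with the new method:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

object NestedExecutionExample {
  // Hypothetical helper that runs an action from inside code that has already
  // set spark.sql.execution.id (e.g. a command executed by another query).
  // Wrapping the nested action with ignoreNestedExecutionId avoids the
  // "spark.sql.execution.id is already set" error; the Spark jobs issued
  // inside the wrapped block are simply not tracked as a SQL execution in the UI.
  def eagerCount(sparkSession: SparkSession, tableName: String): Long = {
    SQLExecution.ignoreNestedExecutionId(sparkSession) {
      sparkSession.table(tableName).count()
    }
  }
}
```

The call sites changed in this patch (e.g. `AnalyzeTableCommand` and `CacheTableCommand` in the diff below) follow exactly this shape.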

The idea comes from rdblue in https://github.com/apache/spark/pull/17540/files#diff-ab49028253e599e6e74cc4f4dcb2e3a8R57.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #18419 from cloud-fan/follow.
Wenchen Fan 2017-06-27 02:35:51 +08:00
parent 9e50a1d37a
commit c22810004f
8 changed files with 87 additions and 77 deletions


@@ -246,13 +246,8 @@ class Dataset[T] private[sql](
_numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
val numRows = _numRows.max(0)
val takeResult = toDF().take(numRows + 1)
showString(takeResult, numRows, truncate, vertical)
}
private def showString(
dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = {
val hasMoreData = dataWithOneMoreRow.length > numRows
val data = dataWithOneMoreRow.take(numRows)
val hasMoreData = takeResult.length > numRows
val data = takeResult.take(numRows)
lazy val timeZone =
DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
@@ -688,19 +683,6 @@ class Dataset[T] private[sql](
println(showString(numRows, truncate = 0))
}
// An internal version of `show`, which won't set execution id and trigger listeners.
private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = {
val numRows = _numRows.max(0)
val takeResult = toDF().takeInternal(numRows + 1)
if (truncate) {
println(showString(takeResult, numRows, truncate = 20, vertical = false))
} else {
println(showString(takeResult, numRows, truncate = 0, vertical = false))
}
}
// scalastyle:on println
/**
* Displays the Dataset in a tabular form. For example:
* {{{
@@ -2467,11 +2449,6 @@ class Dataset[T] private[sql](
*/
def take(n: Int): Array[T] = head(n)
// An internal version of `take`, which won't set execution id and trigger listeners.
private[sql] def takeInternal(n: Int): Array[T] = {
collectFromPlan(limit(n).queryExecution.executedPlan)
}
/**
* Returns the first `n` rows in the Dataset as a list.
*
@@ -2496,11 +2473,6 @@ class Dataset[T] private[sql](
*/
def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)
// An internal version of `collect`, which won't set execution id and trigger listeners.
private[sql] def collectInternal(): Array[T] = {
collectFromPlan(queryExecution.executedPlan)
}
/**
* Returns a Java list that contains all rows in this Dataset.
*
@@ -2542,11 +2514,6 @@ class Dataset[T] private[sql](
plan.executeCollect().head.getLong(0)
}
// An internal version of `count`, which won't set execution id and trigger listeners.
private[sql] def countInternal(): Long = {
groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0)
}
/**
* Returns a new Dataset that has exactly `numPartitions` partitions.
*
@@ -2792,7 +2759,7 @@ class Dataset[T] private[sql](
createTempViewCommand(viewName, replace = true, global = true)
}
private[spark] def createTempViewCommand(
private def createTempViewCommand(
viewName: String,
replace: Boolean,
global: Boolean): CreateViewCommand = {


@@ -29,6 +29,8 @@ object SQLExecution {
val EXECUTION_ID_KEY = "spark.sql.execution.id"
private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId"
private val _nextExecutionId = new AtomicLong(0)
private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
@@ -42,8 +44,11 @@ object SQLExecution {
private val testing = sys.props.contains("spark.testing")
private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
val sc = sparkSession.sparkContext
val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null
val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null
// only throw an exception during tests. a missing execution ID should not fail a job.
if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) {
if (testing && !isNestedExecution && !hasExecutionId) {
// Attention testers: when a test fails with this exception, it means that the action that
// started execution of a query didn't call withNewExecutionId. The execution ID should be
// set by calling withNewExecutionId in the action that begins execution, like
@@ -65,7 +70,7 @@ object SQLExecution {
val executionId = SQLExecution.nextExecutionId
sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
executionIdToQueryExecution.put(executionId, queryExecution)
val r = try {
try {
// sparkContext.getCallSite() would first try to pick up any call site that was previously
// set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
// streaming queries would give us call site like "run at <unknown>:0"
@@ -84,7 +89,15 @@ object SQLExecution {
executionIdToQueryExecution.remove(executionId)
sc.setLocalProperty(EXECUTION_ID_KEY, null)
}
r
} else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) {
// If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the
// `body`, so that Spark jobs issued in the `body` won't be tracked.
try {
sc.setLocalProperty(EXECUTION_ID_KEY, null)
body
} finally {
sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
}
} else {
// Don't support nested `withNewExecutionId`. This is an example of the nested
// `withNewExecutionId`:
@@ -100,7 +113,9 @@ object SQLExecution {
// all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
//
// A real case is the `DataFrame.count` method.
throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set")
throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " +
"action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " +
"jobs issued by the nested execution.")
}
}
@@ -118,4 +133,20 @@ object SQLExecution {
sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
}
}
/**
* Wrap an action which may have nested execution id. This method can be used to run an execution
* inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that,
* all Spark jobs issued in the body won't be tracked in UI.
*/
def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = {
val sc = sparkSession.sparkContext
val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID)
try {
sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true")
body
} finally {
sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue)
}
}
}


@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.internal.SessionState
@@ -58,7 +59,9 @@ case class AnalyzeTableCommand(
// 2. when total size is changed, `oldRowCount` becomes invalid.
// This is to make sure that we only record the right statistics.
if (!noscan) {
val newRowCount = sparkSession.table(tableIdentWithDB).countInternal()
val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
sparkSession.table(tableIdentWithDB).count()
}
if (newRowCount >= 0 && newRowCount != oldRowCount) {
newStats = if (newStats.isDefined) {
newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))


@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution
case class CacheTableCommand(
tableIdent: TableIdentifier,
@@ -33,16 +34,16 @@ case class CacheTableCommand(
override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq
override def run(sparkSession: SparkSession): Seq[Row] = {
SQLExecution.ignoreNestedExecutionId(sparkSession) {
plan.foreach { logicalPlan =>
Dataset.ofRows(sparkSession, logicalPlan)
.createTempViewCommand(tableIdent.quotedString, replace = false, global = false)
.run(sparkSession)
Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
}
sparkSession.catalog.cacheTable(tableIdent.quotedString)
if (!isLazy) {
// Performs eager caching
sparkSession.table(tableIdent).countInternal()
sparkSession.table(tableIdent).count()
}
}
Seq.empty[Row]


@@ -32,6 +32,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
@@ -144,8 +145,9 @@ object TextInputCSVDataSource extends CSVDataSource {
inputPaths: Seq[FileStatus],
parsedOptions: CSVOptions): StructType = {
val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
val maybeFirstLine =
CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption
val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) {
CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
}
inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
}


@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
@@ -129,14 +130,11 @@ private[sql] case class JDBCRelation(
}
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
import scala.collection.JavaConverters._
val options = jdbcOptions.asProperties.asScala +
("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table)
val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
new JdbcRelationProvider().createRelation(
data.sparkSession.sqlContext, mode, options.toMap, data)
SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
.jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
}
}
override def toString: String = {


@@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.types.StructType
class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
@@ -47,9 +48,11 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
println(batchIdStr)
println("-------------------------------------------")
// scalastyle:off println
SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
data.sparkSession.createDataFrame(
data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema)
.showInternal(numRowsToShow, isTruncated)
data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
.show(numRowsToShow, isTruncated)
}
}
}
@@ -79,7 +82,9 @@ class ConsoleSinkProvider extends StreamSinkProvider
// Truncate the displayed data if it is too long, by default it is true
val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
data.showInternal(numRowsToShow, isTruncated)
SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) {
data.show(numRowsToShow, isTruncated)
}
ConsoleRelation(sqlContext, data)
}


@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -193,13 +194,14 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
}
if (notCommitted) {
logDebug(s"Committing batch $batchId to $this")
SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
outputMode match {
case Append | Update =>
val rows = AddedData(batchId, data.collectInternal())
val rows = AddedData(batchId, data.collect())
synchronized { batches += rows }
case Complete =>
val rows = AddedData(batchId, data.collectInternal())
val rows = AddedData(batchId, data.collect())
synchronized {
batches.clear()
batches += rows
@@ -209,6 +211,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
throw new IllegalArgumentException(
s"Output mode $outputMode is not supported by MemorySink")
}
}
} else {
logDebug(s"Skipping already committed batch: $batchId")
}