[SPARK-20213][SQL][FOLLOW-UP] introduce SQLExecution.ignoreNestedExecutionId
## What changes were proposed in this pull request?

In https://github.com/apache/spark/pull/18064, to work around the nested SQL execution id issue, we introduced several internal methods in `Dataset`, like `collectInternal`, `countInternal`, `showInternal`, etc., to avoid nested execution ids. However, this approach does not extend well: whenever we hit another nested execution id case, we have to add yet another internal method to `Dataset`.

Since the goal is simply to ignore the nested execution id in certain cases, a better approach is to introduce `SQLExecution.ignoreNestedExecutionId`. Whenever we find a place that needs to ignore the nested execution, we can just wrap the action with `SQLExecution.ignoreNestedExecutionId`, which is more extensible than the previous approach.

The idea comes from https://github.com/apache/spark/pull/17540/files#diff-ab49028253e599e6e74cc4f4dcb2e3a8R57 by rdblue.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #18419 from cloud-fan/follow.
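The usage pattern is the one this diff applies to `AnalyzeTableCommand`, `CacheTableCommand`, and the streaming sinks: wrap the nested action in `SQLExecution.ignoreNestedExecutionId`. Below is a minimal sketch of that pattern; `RowCountCommand` and its field are made up for illustration, while `SQLExecution.ignoreNestedExecutionId` and the `Dataset` calls are the ones touched by this patch.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.execution.SQLExecution

// Hypothetical command used only to illustrate the wrapping pattern.
case class RowCountCommand(tableName: String) {
  def run(sparkSession: SparkSession): Seq[Row] = {
    // `run` is assumed to already execute under an execution id set by
    // `withNewExecutionId`; calling `count()` directly would nest another
    // execution and hit the "spark.sql.execution.id is already set" error.
    // Wrapping the action ignores the nested execution id, so the Spark jobs
    // issued by `count()` are simply not tracked in the UI.
    val rowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
      sparkSession.table(tableName).count()
    }
    Seq(Row(rowCount))
  }
}
```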
parent 9e50a1d37a
commit c22810004f
@@ -246,13 +246,8 @@ class Dataset[T] private[sql](
       _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
     val numRows = _numRows.max(0)
     val takeResult = toDF().take(numRows + 1)
-    showString(takeResult, numRows, truncate, vertical)
-  }
-
-  private def showString(
-      dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = {
-    val hasMoreData = dataWithOneMoreRow.length > numRows
-    val data = dataWithOneMoreRow.take(numRows)
+    val hasMoreData = takeResult.length > numRows
+    val data = takeResult.take(numRows)
 
     lazy val timeZone =
       DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
@@ -688,19 +683,6 @@ class Dataset[T] private[sql](
     println(showString(numRows, truncate = 0))
   }
 
-  // An internal version of `show`, which won't set execution id and trigger listeners.
-  private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = {
-    val numRows = _numRows.max(0)
-    val takeResult = toDF().takeInternal(numRows + 1)
-
-    if (truncate) {
-      println(showString(takeResult, numRows, truncate = 20, vertical = false))
-    } else {
-      println(showString(takeResult, numRows, truncate = 0, vertical = false))
-    }
-  }
   // scalastyle:on println
 
   /**
    * Displays the Dataset in a tabular form. For example:
    * {{{
@@ -2467,11 +2449,6 @@ class Dataset[T] private[sql](
    */
   def take(n: Int): Array[T] = head(n)
 
-  // An internal version of `take`, which won't set execution id and trigger listeners.
-  private[sql] def takeInternal(n: Int): Array[T] = {
-    collectFromPlan(limit(n).queryExecution.executedPlan)
-  }
-
   /**
    * Returns the first `n` rows in the Dataset as a list.
    *
@@ -2496,11 +2473,6 @@ class Dataset[T] private[sql](
    */
   def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)
 
-  // An internal version of `collect`, which won't set execution id and trigger listeners.
-  private[sql] def collectInternal(): Array[T] = {
-    collectFromPlan(queryExecution.executedPlan)
-  }
-
   /**
    * Returns a Java list that contains all rows in this Dataset.
    *
@@ -2542,11 +2514,6 @@ class Dataset[T] private[sql](
     plan.executeCollect().head.getLong(0)
   }
 
-  // An internal version of `count`, which won't set execution id and trigger listeners.
-  private[sql] def countInternal(): Long = {
-    groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0)
-  }
-
   /**
    * Returns a new Dataset that has exactly `numPartitions` partitions.
   *
@@ -2792,7 +2759,7 @@ class Dataset[T] private[sql](
     createTempViewCommand(viewName, replace = true, global = true)
   }
 
-  private[spark] def createTempViewCommand(
+  private def createTempViewCommand(
       viewName: String,
       replace: Boolean,
       global: Boolean): CreateViewCommand = {

@@ -29,6 +29,8 @@ object SQLExecution {
 
   val EXECUTION_ID_KEY = "spark.sql.execution.id"
 
+  private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId"
+
   private val _nextExecutionId = new AtomicLong(0)
 
   private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
@@ -42,8 +44,11 @@ object SQLExecution {
   private val testing = sys.props.contains("spark.testing")
 
   private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
+    val sc = sparkSession.sparkContext
+    val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null
+    val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null
     // only throw an exception during tests. a missing execution ID should not fail a job.
-    if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) {
+    if (testing && !isNestedExecution && !hasExecutionId) {
       // Attention testers: when a test fails with this exception, it means that the action that
       // started execution of a query didn't call withNewExecutionId. The execution ID should be
       // set by calling withNewExecutionId in the action that begins execution, like
@@ -65,7 +70,7 @@ object SQLExecution {
       val executionId = SQLExecution.nextExecutionId
       sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
       executionIdToQueryExecution.put(executionId, queryExecution)
-      val r = try {
+      try {
         // sparkContext.getCallSite() would first try to pick up any call site that was previously
         // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
         // streaming queries would give us call site like "run at <unknown>:0"
@@ -84,7 +89,15 @@ object SQLExecution {
         executionIdToQueryExecution.remove(executionId)
         sc.setLocalProperty(EXECUTION_ID_KEY, null)
       }
-      r
+    } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) {
+      // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the
+      // `body`, so that Spark jobs issued in the `body` won't be tracked.
+      try {
+        sc.setLocalProperty(EXECUTION_ID_KEY, null)
+        body
+      } finally {
+        sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
+      }
     } else {
       // Don't support nested `withNewExecutionId`. This is an example of the nested
       // `withNewExecutionId`:
@@ -100,7 +113,9 @@ object SQLExecution {
       // all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
       //
       // A real case is the `DataFrame.count` method.
-      throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set")
+      throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " +
+        "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " +
+        "jobs issued by the nested execution.")
     }
   }
 
@@ -118,4 +133,20 @@ object SQLExecution {
       sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
     }
   }
+
+  /**
+   * Wrap an action which may have nested execution id. This method can be used to run an execution
+   * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that,
+   * all Spark jobs issued in the body won't be tracked in UI.
+   */
+  def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = {
+    val sc = sparkSession.sparkContext
+    val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID)
+    try {
+      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true")
+      body
+    } finally {
+      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue)
+    }
+  }
 }

@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.internal.SessionState
 
 
@@ -58,7 +59,9 @@ case class AnalyzeTableCommand(
     // 2. when total size is changed, `oldRowCount` becomes invalid.
     // This is to make sure that we only record the right statistics.
     if (!noscan) {
-      val newRowCount = sparkSession.table(tableIdentWithDB).countInternal()
+      val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
+        sparkSession.table(tableIdentWithDB).count()
+      }
       if (newRowCount >= 0 && newRowCount != oldRowCount) {
         newStats = if (newStats.isDefined) {
           newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))

@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SQLExecution
 
 case class CacheTableCommand(
     tableIdent: TableIdentifier,
@@ -33,16 +34,16 @@ case class CacheTableCommand(
   override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
-    plan.foreach { logicalPlan =>
-      Dataset.ofRows(sparkSession, logicalPlan)
-        .createTempViewCommand(tableIdent.quotedString, replace = false, global = false)
-        .run(sparkSession)
-    }
-    sparkSession.catalog.cacheTable(tableIdent.quotedString)
+    SQLExecution.ignoreNestedExecutionId(sparkSession) {
+      plan.foreach { logicalPlan =>
+        Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
+      }
+      sparkSession.catalog.cacheTable(tableIdent.quotedString)
 
-    if (!isLazy) {
-      // Performs eager caching
-      sparkSession.table(tableIdent).countInternal()
+      if (!isLazy) {
+        // Performs eager caching
+        sparkSession.table(tableIdent).count()
+      }
     }
 
     Seq.empty[Row]

@@ -32,6 +32,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
@@ -144,8 +145,9 @@ object TextInputCSVDataSource extends CSVDataSource {
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    val maybeFirstLine =
-      CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption
+    val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) {
+      CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
+    }
     inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
   }
 

@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.Partition
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
@@ -129,14 +130,11 @@ private[sql] case class JDBCRelation(
   }
 
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
-    import scala.collection.JavaConverters._
-
-    val options = jdbcOptions.asProperties.asScala +
-      ("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table)
-    val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
-
-    new JdbcRelationProvider().createRelation(
-      data.sparkSession.sqlContext, mode, options.toMap, data)
+    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+      data.write
+        .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
+        .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
+    }
   }
 
   override def toString: String = {

@@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.types.StructType
 
 class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
@@ -47,9 +48,11 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
     println(batchIdStr)
     println("-------------------------------------------")
     // scalastyle:off println
-    data.sparkSession.createDataFrame(
-      data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema)
-      .showInternal(numRowsToShow, isTruncated)
+    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+      data.sparkSession.createDataFrame(
+        data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
+        .show(numRowsToShow, isTruncated)
+    }
   }
 }
 
@@ -79,7 +82,9 @@ class ConsoleSinkProvider extends StreamSinkProvider
 
     // Truncate the displayed data if it is too long, by default it is true
     val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
-    data.showInternal(numRowsToShow, isTruncated)
+    SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) {
+      data.show(numRowsToShow, isTruncated)
+    }
 
     ConsoleRelation(sqlContext, data)
   }

@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -193,21 +194,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
     }
     if (notCommitted) {
       logDebug(s"Committing batch $batchId to $this")
-      outputMode match {
-        case Append | Update =>
-          val rows = AddedData(batchId, data.collectInternal())
-          synchronized { batches += rows }
-
-        case Complete =>
-          val rows = AddedData(batchId, data.collectInternal())
-          synchronized {
-            batches.clear()
-            batches += rows
-          }
-
-        case _ =>
-          throw new IllegalArgumentException(
-            s"Output mode $outputMode is not supported by MemorySink")
+      SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+        outputMode match {
+          case Append | Update =>
+            val rows = AddedData(batchId, data.collect())
+            synchronized { batches += rows }
+
+          case Complete =>
+            val rows = AddedData(batchId, data.collect())
+            synchronized {
+              batches.clear()
+              batches += rows
+            }
+
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Output mode $outputMode is not supported by MemorySink")
+        }
       }
     } else {
       logDebug(s"Skipping already committed batch: $batchId")