[SPARK-35378][SQL] Eagerly execute commands in QueryExecution instead of caller sides

### What changes were proposed in this pull request?
Currently, Spark eagerly executes commands on the caller side of `QueryExecution`. This is a bit hacky: `QueryExecution` is not aware of the eager execution, which leads to confusion.

For example, if you run `sql("show tables").collect()`, you will see two queries with identical query plans in the web UI.
![image](https://user-images.githubusercontent.com/3182036/121193729-a72d0480-c8a0-11eb-8b12-379019607ad5.png)
![image](https://user-images.githubusercontent.com/3182036/121193822-bc099800-c8a0-11eb-9d2a-34ab1329e2f7.png)
![image](https://user-images.githubusercontent.com/3182036/121193845-c0ce4c00-c8a0-11eb-96d0-ef604a4dfab0.png)

The first query is triggered at `Dataset.logicalPlan`, which eagerly executes the command.
The second query is triggered at `Dataset.collect`, which is the normal query execution.

From the web UI, it's hard to tell that these two queries are caused by eager command execution.
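
The eager execution itself is intentional: a command such as `sql("INSERT ...")` must take effect without an explicit action. A minimal sketch of that behavior, assuming a `spark-shell` session where `spark` is the active `SparkSession`:

```scala
// Create a table and insert into it purely via SQL commands; the side effects
// happen as soon as each Dataset is created, with no .collect() needed.
spark.sql("CREATE TABLE t(i INT) USING parquet")
spark.sql("INSERT INTO t VALUES (1)")
assert(spark.table("t").count() == 1)
```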

This PR proposes to move the eager command execution into `QueryExecution` and to turn the command plan into `CommandResult`, indicating that the command has already been executed. `sql("show tables").collect()` still triggers two queries, but the query plans are no longer identical. The second query becomes:
![image](https://user-images.githubusercontent.com/3182036/121194850-b3659180-c8a1-11eb-9abf-2980f84f089d.png)
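
With this change the already-executed command is visible from the plan itself. A minimal sketch of how it can be observed, assuming a `spark-shell` session (`spark`) and the `commandExecuted` / `CommandResult` / `CommandResultExec` additions in this diff:

```scala
import org.apache.spark.sql.execution.CommandResultExec
import org.apache.spark.sql.expressions.CommandResult

val df = spark.sql("SHOW TABLES")
// The command has already been executed eagerly inside QueryExecution;
// its result rows are held by a CommandResult logical node.
assert(df.queryExecution.commandExecuted.isInstanceOf[CommandResult])
assert(df.queryExecution.executedPlan.isInstanceOf[CommandResultExec])
// collect() just returns the pre-computed rows instead of re-running the command.
df.collect()
```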

In addition to the UI improvement, this PR has other benefits:
1. It simplifies the code: the caller side no longer needs to worry about eager command execution, as `QueryExecution` takes care of it.
2. It helps https://github.com/apache/spark/pull/32442, where there can be more plan nodes above commands and we need to replace commands with something like a local relation that produces unsafe rows.

### Why are the changes needed?
Explained above.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
existing tests

Closes #32513 from beliefer/SPARK-35378.

Lead-authored-by: gengjiaan <gengjiaan@360.cn>
Co-authored-by: beliefer <beliefer@163.com>
Co-authored-by: Jiaan Geng <beliefer@163.com>
Co-authored-by: Wenchen Fan <cloud0fan@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
gengjiaan 2021-06-09 04:45:44 +00:00 committed by Wenchen Fan
parent afff42178c
commit 8013f985a4
23 changed files with 347 additions and 115 deletions

View file

@ -95,6 +95,8 @@ license: |
- In Spark 3.2, `FloatType` is mapped to `FLOAT` in MySQL. Prior to this, it used to be mapped to `REAL`, which is by default a synonym to `DOUBLE PRECISION` in MySQL.
- In Spark 3.2, the query executions triggered by `DataFrameWriter` are always named `command` when being sent to `QueryExecutionListener`. In Spark 3.1 and earlier, the name is one of `save`, `insertInto`, `saveAsTable`, `create`, `append`, `overwrite`, `overwritePartitions`, `replace`.
## Upgrading from Spark SQL 3.0 to 3.1
- In Spark 3.1, statistical aggregation function includes `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, `corr` will return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` applied on a single element set. In Spark version 3.0 and earlier, it will return `Double.NaN` in such case. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`.
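
To illustrate the `QueryExecutionListener` naming change called out above, a minimal sketch (illustration only, not part of the diff), assuming a local `SparkSession` named `spark` and a writable temp path:

```scala
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

spark.listenerManager.register(new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    // In Spark 3.2 this reports "command"; in 3.1 and earlier it would be e.g. "save".
    println(s"funcName = $funcName")
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
})
spark.range(3).write.mode("overwrite").parquet("/tmp/spark-35378-demo")
```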

View file

@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, V1Table}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2._
@ -311,13 +310,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val relation = DataSourceV2Relation.create(table, catalog, ident, dsOptions)
checkPartitioningMatchesV2Table(table)
if (mode == SaveMode.Append) {
runCommand(df.sparkSession, "save") {
runCommand(df.sparkSession) {
AppendData.byName(relation, df.logicalPlan, finalOptions)
}
} else {
// Truncate the table. TableCapabilityCheck will throw a nice exception if this
// isn't supported
runCommand(df.sparkSession, "save") {
runCommand(df.sparkSession) {
OverwriteByExpression.byName(
relation, df.logicalPlan, Literal(true), finalOptions)
}
@ -332,7 +331,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val location = Option(dsOptions.get("path")).map(TableCatalog.PROP_LOCATION -> _)
runCommand(df.sparkSession, "save") {
runCommand(df.sparkSession) {
CreateTableAsSelect(
catalog,
ident,
@ -379,7 +378,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val optionsWithPath = getOptionsWithPath(path)
// Code path for data source v1.
runCommand(df.sparkSession, "save") {
runCommand(df.sparkSession) {
DataSource(
sparkSession = df.sparkSession,
className = source,
@ -475,13 +474,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
}
runCommand(df.sparkSession, "insertInto") {
runCommand(df.sparkSession) {
command
}
}
private def insertInto(tableIdent: TableIdentifier): Unit = {
runCommand(df.sparkSession, "insertInto") {
runCommand(df.sparkSession) {
InsertIntoStatement(
table = UnresolvedRelation(tableIdent),
partitionSpec = Map.empty[String, Option[String]],
@ -631,7 +630,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
external = false)
}
runCommand(df.sparkSession, "saveAsTable") {
runCommand(df.sparkSession) {
command
}
}
@ -698,7 +697,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
partitionColumnNames = partitioningColumns.getOrElse(Nil),
bucketSpec = getBucketSpec)
runCommand(df.sparkSession, "saveAsTable")(
runCommand(df.sparkSession)(
CreateTable(tableDesc, mode, Some(df.logicalPlan)))
}
@ -856,10 +855,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* Wrap a DataFrameWriter action to track the QueryExecution and time cost, then report to the
* user-registered callback functions.
*/
private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
private def runCommand(session: SparkSession)(command: LogicalPlan): Unit = {
val qe = session.sessionState.executePlan(command)
// call `QueryExecution.toRDD` to trigger the execution of commands.
SQLExecution.withNewExecutionId(qe, Some(name))(qe.toRdd)
// call `QueryExecution.commandExecuted` to trigger the execution of commands.
qe.commandExecuted
}
private def lookupV2Provider(): Option[TableProvider] = {

View file

@ -107,7 +107,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
}
override def create(): Unit = {
runCommand("create") {
runCommand(
CreateTableAsSelectStatement(
tableName,
logicalPlan,
@ -121,8 +121,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
options.toMap,
None,
ifNotExists = false,
external = false)
}
external = false))
}
override def replace(): Unit = {
@ -146,7 +145,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
@throws(classOf[NoSuchTableException])
def append(): Unit = {
val append = AppendData.byName(UnresolvedRelation(tableName), logicalPlan, options.toMap)
runCommand("append")(append)
runCommand(append)
}
/**
@ -163,7 +162,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
def overwrite(condition: Column): Unit = {
val overwrite = OverwriteByExpression.byName(
UnresolvedRelation(tableName), logicalPlan, condition.expr, options.toMap)
runCommand("overwrite")(overwrite)
runCommand(overwrite)
}
/**
@ -183,21 +182,21 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
def overwritePartitions(): Unit = {
val dynamicOverwrite = OverwritePartitionsDynamic.byName(
UnresolvedRelation(tableName), logicalPlan, options.toMap)
runCommand("overwritePartitions")(dynamicOverwrite)
runCommand(dynamicOverwrite)
}
/**
* Wrap an action to track the QueryExecution and time cost, then report to the user-registered
* callback functions.
*/
private def runCommand(name: String)(command: LogicalPlan): Unit = {
private def runCommand(command: LogicalPlan): Unit = {
val qe = sparkSession.sessionState.executePlan(command)
// call `QueryExecution.toRDD` to trigger the execution of commands.
SQLExecution.withNewExecutionId(qe, Some(name))(qe.toRdd)
SQLExecution.withNewExecutionId(qe, Some("command"))(qe.toRdd)
}
private def internalReplace(orCreate: Boolean): Unit = {
runCommand("replace") {
runCommand(
ReplaceTableAsSelectStatement(
tableName,
logicalPlan,
@ -210,8 +209,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
None,
options.toMap,
None,
orCreate = orCreate)
}
orCreate = orCreate))
}
}

View file

@ -221,16 +221,7 @@ class Dataset[T] private[sql](
}
@transient private[sql] val logicalPlan: LogicalPlan = {
// For various commands (like DDL) and queries with side effects, we force query execution
// to happen right away to let these side effects take place eagerly.
val plan = queryExecution.analyzed match {
case c: Command =>
LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect()))
case u @ Union(children, _, _) if children.forall(_.isInstanceOf[Command]) =>
LocalRelation(u.output, withAction("command", queryExecution)(_.executeCollect()))
case _ =>
queryExecution.analyzed
}
val plan = queryExecution.commandExecuted
if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long])
dsIds.add(id)

View file

@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* Physical plan node for holding data from a command.
*
* `commandPhysicalPlan` is just used to display the plan tree for EXPLAIN.
* `rows` may not be serializable and ideally we should not send `rows` to the executors.
* Thus marking them as transient.
*/
case class CommandResultExec(
output: Seq[Attribute],
@transient commandPhysicalPlan: SparkPlan,
@transient rows: Seq[InternalRow]) extends LeafExecNode with InputRDDCodegen {
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def innerChildren: Seq[QueryPlan[_]] = Seq(commandPhysicalPlan)
@transient private lazy val unsafeRows: Array[InternalRow] = {
if (rows.isEmpty) {
Array.empty
} else {
val proj = UnsafeProjection.create(output, output)
rows.map(r => proj(r).copy()).toArray
}
}
@transient private lazy val rdd: RDD[InternalRow] = {
if (rows.isEmpty) {
sqlContext.sparkContext.emptyRDD
} else {
val numSlices = math.min(
unsafeRows.length, sqlContext.sparkSession.leafNodeDefaultParallelism)
sqlContext.sparkContext.parallelize(unsafeRows, numSlices)
}
}
override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
rdd.map { r =>
numOutputRows += 1
r
}
}
override protected def stringArgs: Iterator[Any] = {
if (unsafeRows.isEmpty) {
Iterator("<empty>", output)
} else {
Iterator(output)
}
}
override def executeCollect(): Array[InternalRow] = {
longMetric("numOutputRows").add(rows.size)
rows.toArray
}
override def executeTake(limit: Int): Array[InternalRow] = {
val taken = unsafeRows.take(limit)
longMetric("numOutputRows").add(taken.size)
taken
}
override def executeTail(limit: Int): Array[InternalRow] = {
val taken: Seq[InternalRow] = unsafeRows.takeRight(limit)
longMetric("numOutputRows").add(taken.size)
taken.toArray
}
// Input is already UnsafeRows.
override protected val createUnsafeProjection: Boolean = false
override def inputRDD: RDD[InternalRow] = rdd
}

View file

@ -44,36 +44,42 @@ object HiveResult {
TimeFormatters(dateFormatter, timestampFormatter)
}
private def stripRootCommandResult(executedPlan: SparkPlan): SparkPlan = executedPlan match {
case CommandResultExec(_, plan, _) => plan
case other => other
}
/**
* Returns the result as a hive compatible sequence of strings. This is used in tests and
* `SparkSQLDriver` for CLI applications.
*/
def hiveResultString(executedPlan: SparkPlan): Seq[String] = executedPlan match {
case ExecutedCommandExec(_: DescribeCommandBase) =>
formatDescribeTableOutput(executedPlan.executeCollectPublic())
case _: DescribeTableExec =>
formatDescribeTableOutput(executedPlan.executeCollectPublic())
// SHOW TABLES in Hive only output table names while our v1 command outputs
// database, table name, isTemp.
case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended =>
command.executeCollect().map(_.getString(1))
// SHOW TABLES in Hive only output table names while our v2 command outputs
// namespace and table name.
case command : ShowTablesExec =>
command.executeCollect().map(_.getString(1))
// SHOW VIEWS in Hive only outputs view names while our v1 command outputs
// namespace, viewName, and isTemporary.
case command @ ExecutedCommandExec(_: ShowViewsCommand) =>
command.executeCollect().map(_.getString(1))
case other =>
val timeFormatters = getTimeFormatters
val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq
// We need the types so we can output struct field names
val types = executedPlan.output.map(_.dataType)
// Reformat to match hive tab delimited output.
result.map(_.zip(types).map(e => toHiveString(e, false, timeFormatters)))
.map(_.mkString("\t"))
}
def hiveResultString(executedPlan: SparkPlan): Seq[String] =
stripRootCommandResult(executedPlan) match {
case ExecutedCommandExec(_: DescribeCommandBase) =>
formatDescribeTableOutput(executedPlan.executeCollectPublic())
case _: DescribeTableExec =>
formatDescribeTableOutput(executedPlan.executeCollectPublic())
// SHOW TABLES in Hive only output table names while our v1 command outputs
// database, table name, isTemp.
case ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended =>
executedPlan.executeCollect().map(_.getString(1))
// SHOW TABLES in Hive only output table names while our v2 command outputs
// namespace and table name.
case _ : ShowTablesExec =>
executedPlan.executeCollect().map(_.getString(1))
// SHOW VIEWS in Hive only outputs view names while our v1 command outputs
// namespace, viewName, and isTemporary.
case ExecutedCommandExec(_: ShowViewsCommand) =>
executedPlan.executeCollect().map(_.getString(1))
case other =>
val timeFormatters = getTimeFormatters
val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq
// We need the types so we can output struct field names
val types = executedPlan.output.map(_.dataType)
// Reformat to match hive tab delimited output.
result.map(_.zip(types).map(e => toHiveString(e, false, timeFormatters)))
.map(_.mkString("\t"))
}
private def formatDescribeTableOutput(rows: Array[Row]): Seq[String] = {
rows.map {

View file

@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
import org.apache.spark.sql.catalyst.util.truncatedString
@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableU
import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
import org.apache.spark.sql.expressions.CommandResult
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.Utils
@ -53,7 +54,8 @@ import org.apache.spark.util.Utils
class QueryExecution(
val sparkSession: SparkSession,
val logical: LogicalPlan,
val tracker: QueryPlanningTracker = new QueryPlanningTracker) extends Logging {
val tracker: QueryPlanningTracker = new QueryPlanningTracker,
val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL) extends Logging {
val id: Long = QueryExecution.nextExecutionId
@ -73,23 +75,51 @@ class QueryExecution(
sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker)
}
lazy val commandExecuted: LogicalPlan = mode match {
case CommandExecutionMode.NON_ROOT => analyzed.mapChildren(eagerlyExecuteCommands)
case CommandExecutionMode.ALL => eagerlyExecuteCommands(analyzed)
case CommandExecutionMode.SKIP => analyzed
}
private def eagerlyExecuteCommands(p: LogicalPlan) = p transformDown {
case c: Command =>
val qe = sparkSession.sessionState.executePlan(c, CommandExecutionMode.NON_ROOT)
val result =
SQLExecution.withNewExecutionId(qe, Some("command"))(qe.executedPlan.executeCollect())
CommandResult(
qe.analyzed.output,
qe.commandExecuted,
qe.executedPlan,
result)
case other => other
}
lazy val withCachedData: LogicalPlan = sparkSession.withActive {
assertAnalyzed()
assertSupported()
// clone the plan to avoid sharing the plan instance between different stages like analyzing,
// optimizing and planning.
sparkSession.sharedState.cacheManager.useCachedData(analyzed.clone())
sparkSession.sharedState.cacheManager.useCachedData(commandExecuted.clone())
}
lazy val optimizedPlan: LogicalPlan = executePhase(QueryPlanningTracker.OPTIMIZATION) {
// clone the plan to avoid sharing the plan instance between different stages like analyzing,
// optimizing and planning.
val plan = sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker)
// We do not want optimized plans to be re-analyzed as literals that have been constant folded
// and such can cause issues during analysis. While `clone` should maintain the `analyzed` state
// of the LogicalPlan, we set the plan as analyzed here as well out of paranoia.
plan.setAnalyzed()
plan
private def assertCommandExecuted(): Unit = commandExecuted
lazy val optimizedPlan: LogicalPlan = {
// We need to materialize the commandExecuted here because optimizedPlan is also tracked under
// the optimizing phase
assertCommandExecuted()
executePhase(QueryPlanningTracker.OPTIMIZATION) {
// clone the plan to avoid sharing the plan instance between different stages like analyzing,
// optimizing and planning.
val plan =
sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker)
// We do not want optimized plans to be re-analyzed as literals that have been constant
// folded and such can cause issues during analysis. While `clone` should maintain the
// `analyzed` state of the LogicalPlan, we set the plan as analyzed here as well out of
// paranoia.
plan.setAnalyzed()
plan
}
}
private def assertOptimized(): Unit = optimizedPlan
@ -333,6 +363,19 @@ class QueryExecution(
}
}
/**
* SPARK-35378: Commands should be executed eagerly so that something like `sql("INSERT ...")`
* can trigger the table insertion immediately without a `.collect()`. To avoid end-less recursion
* we should use `NON_ROOT` when recursively executing commands. Note that we can't execute
* a query plan with leaf command nodes, because many commands return `GenericInternalRow`
* and can't be put in a query plan directly, otherwise the query engine may cast
* `GenericInternalRow` to `UnsafeRow` and fail. When running EXPLAIN, or commands inside other
* command, we should use `SKIP` to not eagerly trigger the command execution.
*/
object CommandExecutionMode extends Enumeration {
val SKIP, NON_ROOT, ALL = Value
}
object QueryExecution {
private val _nextExecutionId = new AtomicLong(0)

View file

@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NU
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
import org.apache.spark.sql.expressions.CommandResult
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
@ -697,6 +698,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data, _) =>
LocalTableScanExec(output, data) :: Nil
case CommandResult(output, _, plan, data) => CommandResultExec(output, plan, data) :: Nil
case logical.LocalLimit(IntegerLiteral(limit), child) =>
execution.LocalLimitExec(limit, planLater(child)) :: Nil
case logical.GlobalLimit(IntegerLiteral(limit), child) =>

View file

@ -934,6 +934,10 @@ case class CollapseCodegenStages(
// Do not make LogicalTableScanExec the root of WholeStageCodegen
// to support the fast driver-local collect/take paths.
plan
case plan: CommandResultExec =>
// Do not make CommandResultExec the root of WholeStageCodegen
// to support the fast driver-local collect/take paths.
plan
case plan: CodegenSupport if supportCodegen(plan) =>
// The whole-stage-codegen framework is row-based. If a plan supports columnar execution,
// it can't support whole-stage-codegen at the same time.

View file

@ -45,6 +45,7 @@ case class InsertAdaptiveSparkPlan(
private def applyInternal(plan: SparkPlan, isSubquery: Boolean): SparkPlan = plan match {
case _ if !conf.adaptiveExecutionEnabled => plan
case _: ExecutedCommandExec => plan
case _: CommandResultExec => plan
case c: DataWritingCommandExec => c.copy(child = apply(c.child))
case c: V2CommandExec => c.withNewChildren(c.children.map(apply))
case _ if shouldApplyAQE(plan, isSubquery) =>

View file

@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.LeafLike
import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.{CommandExecutionMode, ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.IncrementalExecution
import org.apache.spark.sql.types._
@ -163,7 +163,8 @@ case class ExplainCommand(
// Run through the optimizer to generate the physical plan.
override def run(sparkSession: SparkSession): Seq[Row] = try {
val outputString = sparkSession.sessionState.executePlan(logicalPlan).explainString(mode)
val outputString = sparkSession.sessionState.executePlan(logicalPlan, CommandExecutionMode.SKIP)
.explainString(mode)
Seq(Row(outputString))
} catch { case NonFatal(cause) =>
("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_))

View file

@ -23,7 +23,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{CommandExecutionMode, SparkPlan}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@ -195,7 +195,7 @@ case class CreateDataSourceTableAsSelectCommand(
sessionState.executePlan(RepairTableCommand(
table.identifier,
enableAddPartitions = true,
enableDropPartitions = false)).toRdd
enableDropPartitions = false), CommandExecutionMode.SKIP).toRdd
case _ =>
}
}

View file

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.execution.SparkPlan
/**
* Logical plan node for holding data from a command.
*
* `commandLogicalPlan` and `commandPhysicalPlan` are just used to display the plan tree
* for EXPLAIN.
* `rows` may not be serializable and ideally we should not send `rows` to the executors.
* Thus marking them as transient.
*/
case class CommandResult(
output: Seq[Attribute],
@transient commandLogicalPlan: LogicalPlan,
@transient commandPhysicalPlan: SparkPlan,
@transient rows: Seq[InternalRow]) extends LeafNode {
override def innerChildren: Seq[QueryPlan[_]] = Seq(commandLogicalPlan)
override def computeStats(): Statistics =
Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * rows.length)
}

View file

@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
import org.apache.spark.sql.execution.command.CommandCheck
@ -310,9 +310,9 @@ abstract class BaseSessionStateBuilder(
/**
* Create a query execution object.
*/
protected def createQueryExecution: LogicalPlan => QueryExecution = { plan =>
new QueryExecution(session, plan)
}
protected def createQueryExecution:
(LogicalPlan, CommandExecutionMode.Value) => QueryExecution =
(plan, mode) => new QueryExecution(session, plan, mode = mode)
/**
* Interface to start and stop streaming queries.

View file

@ -76,7 +76,7 @@ private[sql] class SessionState(
val streamingQueryManagerBuilder: () => StreamingQueryManager,
val listenerManager: ExecutionListenerManager,
resourceLoaderBuilder: () => SessionResourceLoader,
createQueryExecution: LogicalPlan => QueryExecution,
createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) => QueryExecution,
createClone: (SparkSession, SessionState) => SessionState,
val columnarRules: Seq[ColumnarRule],
val queryStagePrepRules: Seq[Rule[SparkPlan]]) {
@ -119,7 +119,10 @@ private[sql] class SessionState(
// Helper methods, partially leftover from pre-2.0 days
// ------------------------------------------------------
def executePlan(plan: LogicalPlan): QueryExecution = createQueryExecution(plan)
def executePlan(
plan: LogicalPlan,
mode: CommandExecutionMode.Value = CommandExecutionMode.ALL): QueryExecution =
createQueryExecution(plan, mode)
}
private[sql] object SessionState {

View file

@ -178,7 +178,7 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession {
inputData.write.format(format).save(path.getCanonicalPath)
sparkContext.listenerBus.waitUntilEmpty()
assert(commands.length == 1)
assert(commands.head._1 == "save")
assert(commands.head._1 == "command")
assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand])
assert(commands.head._2.asInstanceOf[InsertIntoHadoopFsRelationCommand]
.fileFormat.isInstanceOf[ParquetFileFormat])

View file

@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder}
import org.apache.spark.sql.connector.expressions.LogicalExpressions._
import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan}
import org.apache.spark.sql.execution.{CommandResultExec, QueryExecution, SortExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
@ -778,7 +778,8 @@ class WriteDistributionAndOrderingSuite
sparkContext.listenerBus.waitUntilEmpty()
executedPlan match {
assert(executedPlan.isInstanceOf[CommandResultExec])
executedPlan.asInstanceOf[CommandResultExec].commandPhysicalPlan match {
case w: V2TableWriteExec =>
stripAQEPlan(w.query)
case _ =>

View file

@ -19,9 +19,12 @@ package org.apache.spark.sql.execution
import scala.io.Source
import org.apache.spark.sql.{AnalysisException, FastOperator}
import org.apache.spark.sql.catalyst.analysis.UnresolvedNamespace
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project, ShowTables, SubqueryAlias}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.command.{ExecutedCommandExec, ShowTablesCommand}
import org.apache.spark.sql.expressions.CommandResult
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
@ -236,4 +239,30 @@ class QueryExecutionSuite extends SharedSparkSession {
assert(df.queryExecution.optimizedPlan.toString.startsWith("Relation default.spark_34129["))
}
}
test("SPARK-35378: Eagerly execute non-root Command") {
def qe(logicalPlan: LogicalPlan): QueryExecution = new QueryExecution(spark, logicalPlan)
val showTables = ShowTables(UnresolvedNamespace(Seq.empty[String]), None)
val showTablesQe = qe(showTables)
assert(showTablesQe.commandExecuted.isInstanceOf[CommandResult])
assert(showTablesQe.executedPlan.isInstanceOf[CommandResultExec])
val showTablesResultExec = showTablesQe.executedPlan.asInstanceOf[CommandResultExec]
assert(showTablesResultExec.commandPhysicalPlan.isInstanceOf[ExecutedCommandExec])
assert(showTablesResultExec.commandPhysicalPlan.asInstanceOf[ExecutedCommandExec]
.cmd.isInstanceOf[ShowTablesCommand])
val project = Project(showTables.output, SubqueryAlias("s", showTables))
val projectQe = qe(project)
assert(projectQe.commandExecuted.isInstanceOf[Project])
assert(projectQe.commandExecuted.children.length == 1)
assert(projectQe.commandExecuted.children(0).isInstanceOf[SubqueryAlias])
assert(projectQe.commandExecuted.children(0).children.length == 1)
assert(projectQe.commandExecuted.children(0).children(0).isInstanceOf[CommandResult])
assert(projectQe.executedPlan.isInstanceOf[CommandResultExec])
val cmdResultExec = projectQe.executedPlan.asInstanceOf[CommandResultExec]
assert(cmdResultExec.commandPhysicalPlan.isInstanceOf[ExecutedCommandExec])
assert(cmdResultExec.commandPhysicalPlan.asInstanceOf[ExecutedCommandExec]
.cmd.isInstanceOf[ShowTablesCommand])
}
}

View file

@ -27,7 +27,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.execution.{LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.{CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
@ -1035,8 +1035,11 @@ class AdaptiveQueryExecSuite
SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
withTable("t1") {
val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan
assert(plan.isInstanceOf[DataWritingCommandExec])
assert(plan.asInstanceOf[DataWritingCommandExec].child.isInstanceOf[AdaptiveSparkPlanExec])
assert(plan.isInstanceOf[CommandResultExec])
val commandResultExec = plan.asInstanceOf[CommandResultExec]
assert(commandResultExec.commandPhysicalPlan.isInstanceOf[DataWritingCommandExec])
assert(commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
.child.isInstanceOf[AdaptiveSparkPlanExec])
}
}
}

View file

@ -25,7 +25,7 @@ import scala.util.Random
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
@ -791,9 +791,10 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
test("SPARK-34567: Add metrics for CTAS operator") {
withTable("t") {
val df = sql("CREATE TABLE t USING PARQUET AS SELECT 1 as a")
assert(df.queryExecution.executedPlan.isInstanceOf[CommandResultExec])
val commandResultExec = df.queryExecution.executedPlan.asInstanceOf[CommandResultExec]
val dataWritingCommandExec =
df.queryExecution.executedPlan.asInstanceOf[DataWritingCommandExec]
dataWritingCommandExec.executeCollect()
commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
val createTableAsSelect = dataWritingCommandExec.cmd
assert(createTableAsSelect.metrics.contains("numFiles"))
assert(createTableAsSelect.metrics("numFiles").value == 1)

View file

@ -20,14 +20,13 @@ package org.apache.spark.sql.util
import scala.collection.mutable.ArrayBuffer
import org.apache.spark._
import org.apache.spark.sql.{functions, AnalysisException, Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.{functions, Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoStatement, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.command.LeafRunnableCommand
import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand}
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StringType
@ -194,7 +193,7 @@ class DataFrameCallbackSuite extends QueryTest
spark.range(10).write.format("json").save(path.getCanonicalPath)
sparkContext.listenerBus.waitUntilEmpty()
assert(commands.length == 1)
assert(commands.head._1 == "save")
assert(commands.head._1 == "command")
assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand])
assert(commands.head._2.asInstanceOf[InsertIntoHadoopFsRelationCommand]
.fileFormat.isInstanceOf[JsonFileFormat])
@ -205,10 +204,10 @@ class DataFrameCallbackSuite extends QueryTest
spark.range(10).write.insertInto("tab")
sparkContext.listenerBus.waitUntilEmpty()
assert(commands.length == 3)
assert(commands(2)._1 == "insertInto")
assert(commands(2)._2.isInstanceOf[InsertIntoStatement])
assert(commands(2)._2.asInstanceOf[InsertIntoStatement].table
.asInstanceOf[UnresolvedRelation].multipartIdentifier == Seq("tab"))
assert(commands(2)._1 == "command")
assert(commands(2)._2.isInstanceOf[InsertIntoHadoopFsRelationCommand])
assert(commands(2)._2.asInstanceOf[InsertIntoHadoopFsRelationCommand]
.catalogTable.get.identifier.identifier == "tab")
}
// exiting withTable adds commands(3) via onSuccess (drops tab)
@ -216,19 +215,21 @@ class DataFrameCallbackSuite extends QueryTest
spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
sparkContext.listenerBus.waitUntilEmpty()
assert(commands.length == 5)
assert(commands(4)._1 == "saveAsTable")
assert(commands(4)._2.isInstanceOf[CreateTable])
assert(commands(4)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p"))
assert(commands(4)._1 == "command")
assert(commands(4)._2.isInstanceOf[CreateDataSourceTableAsSelectCommand])
assert(commands(4)._2.asInstanceOf[CreateDataSourceTableAsSelectCommand]
.table.partitionColumnNames == Seq("p"))
}
withTable("tab") {
sql("CREATE TABLE tab(i long) using parquet")
val e = intercept[AnalysisException] {
spark.range(10).select($"id", $"id").write.insertInto("tab")
spark.udf.register("illegalUdf", udf((value: Long) => value / 0))
val e = intercept[SparkException] {
spark.range(10).selectExpr("illegalUdf(id)").write.insertInto("tab")
}
sparkContext.listenerBus.waitUntilEmpty()
assert(exceptions.length == 1)
assert(exceptions.head._1 == "insertInto")
assert(exceptions.head._1 == "command")
assert(exceptions.head._2 == e)
}
}

View file

@ -17,6 +17,7 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.execution.CommandResultExec
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.metric.SQLMetricsTestUtils
@ -42,9 +43,10 @@ class SQLMetricsSuite extends SQLMetricsTestUtils with TestHiveSingleton
withSQLConf(HiveUtils.CONVERT_METASTORE_CTAS.key -> canOptimized.toString) {
withTable("t") {
val df = sql(s"CREATE TABLE t STORED AS PARQUET AS SELECT 1 as a")
assert(df.queryExecution.executedPlan.isInstanceOf[CommandResultExec])
val commandResultExec = df.queryExecution.executedPlan.asInstanceOf[CommandResultExec]
val dataWritingCommandExec =
df.queryExecution.executedPlan.asInstanceOf[DataWritingCommandExec]
dataWritingCommandExec.executeCollect()
commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
val createTableAsSelect = dataWritingCommandExec.cmd
if (canOptimized) {
assert(createTableAsSelect.isInstanceOf[OptimizedCreateHiveTableAsSelectCommand])

View file

@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.{CommandExecutionMode, QueryExecution, SQLExecution}
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf, WithTestConf}
@ -584,8 +584,9 @@ private[hive] class TestHiveSparkSession(
private[hive] class TestHiveQueryExecution(
sparkSession: TestHiveSparkSession,
logicalPlan: LogicalPlan)
extends QueryExecution(sparkSession, logicalPlan) with Logging {
logicalPlan: LogicalPlan,
mode: CommandExecutionMode.Value = CommandExecutionMode.ALL)
extends QueryExecution(sparkSession, logicalPlan, mode = mode) with Logging {
def this(sparkSession: TestHiveSparkSession, sql: String) = {
this(sparkSession, sparkSession.sessionState.sqlParser.parsePlan(sql))
@ -661,9 +662,10 @@ private[sql] class TestHiveSessionStateBuilder(
override def overrideConfs: Map[String, String] = TestHiveContext.overrideConfs
override def createQueryExecution: (LogicalPlan) => QueryExecution = { plan =>
new TestHiveQueryExecution(session.asInstanceOf[TestHiveSparkSession], plan)
}
override def createQueryExecution:
(LogicalPlan, CommandExecutionMode.Value) => QueryExecution =
(plan, mode) =>
new TestHiveQueryExecution(session.asInstanceOf[TestHiveSparkSession], plan, mode)
override protected def newBuilder: NewBuilder = new TestHiveSessionStateBuilder(_, _)
}