diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 0e032569bb..5e75e26e6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -36,6 +36,8 @@ import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
 import org.apache.spark.sql.internal.SQLConf
@@ -102,6 +104,16 @@ case class AdaptiveSparkPlanExec(
     OptimizeLocalShuffleReader(conf)
   )
 
+  private def finalStageOptimizerRules: Seq[Rule[SparkPlan]] =
+    context.qe.sparkPlan match {
+      case _: DataWritingCommandExec | _: V2TableWriteExec =>
+        // SPARK-32932: Local shuffle reader could break partitioning that works best
+        // for the following writing command
+        queryStageOptimizerRules.filterNot(_.isInstanceOf[OptimizeLocalShuffleReader])
+      case _ =>
+        queryStageOptimizerRules
+    }
+
   // A list of physical optimizer rules to be applied right after a new stage is created. The input
   // plan to these rules has exchange as its root node.
   @transient private val postStageCreationRules = Seq(
@@ -235,7 +247,7 @@ case class AdaptiveSparkPlanExec(
     // Run the final plan when there's no more unfinished stages.
     currentPhysicalPlan = applyPhysicalRules(
       result.newPlan,
-      queryStageOptimizerRules ++ postStageCreationRules,
+      finalStageOptimizerRules ++ postStageCreationRules,
       Some((planChangeLogger, "AQE Final Query Stage Optimization")))
     isFinalPlan = true
     executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 0dfb1d2fd9..38a323b1c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -26,15 +26,19 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
+import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
+import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.util.Utils
 
 class AdaptiveQueryExecSuite
@@ -1258,4 +1262,49 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-32932: Do not use local shuffle reader at final stage on write command") {
+    withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val data = for (
+        i <- 1L to 10L;
+        j <- 1L to 3L
+      ) yield (i, j)
+
+      val df = data.toDF("i", "j").repartition($"j")
+      var noLocalReader: Boolean = false
+      val listener = new QueryExecutionListener {
+        override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+          qe.executedPlan match {
+            case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) =>
+              assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec])
+              noLocalReader = collect(plan) {
+                case exec: CustomShuffleReaderExec if exec.isLocalReader => exec
+              }.isEmpty
+            case _ => // ignore other events
+          }
+        }
+        override def onFailure(funcName: String, qe: QueryExecution,
+          exception: Exception): Unit = {}
+      }
+      spark.listenerManager.register(listener)
+
+      withTable("t") {
+        df.write.partitionBy("j").saveAsTable("t")
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(noLocalReader)
+        noLocalReader = false
+      }
+
+      // Test DataSource v2
+      val format = classOf[NoopDataSource].getName
+      df.write.format(format).mode("overwrite").save()
+      sparkContext.listenerBus.waitUntilEmpty()
+      assert(noLocalReader)
+      noLocalReader = false
+
+      spark.listenerManager.unregister(listener)
+    }
+  }
 }