[SPARK-33494][SQL][AQE] Do not use local shuffle reader for repartition

### What changes were proposed in this pull request?

This PR updates `ShuffleExchangeExec` to carry more information about where the shuffle comes from, so the optimizer knows how much it is allowed to change the partitioning. For `repartition(col)`, we should preserve the user-specified partitioning and not apply the AQE local shuffle reader.
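Condensed from the `ShuffleExchangeExec` hunk further down in this diff, the boolean `canChangeNumPartitions` is replaced by a `ShuffleOrigin` carried on the exchange node (a simplified sketch, not the full file):

```scala
// Describes where the shuffle operator comes from.
sealed trait ShuffleOrigin

// Added by the internal EnsureRequirements rule: Spark is free to optimize the
// shuffle (coalesce partitions, local shuffle reader) as long as the required
// distribution is still ensured.
case object ENSURE_REQUIREMENTS extends ShuffleOrigin

// Added by user-specified repartition(col): Spark may still change the shuffle
// partition number, since that does not change the data partitioning.
case object REPARTITION extends ShuffleOrigin

// Added by user-specified repartition(num, col): Spark can't optimize it.
case object REPARTITION_WITH_NUM extends ShuffleOrigin

case class ShuffleExchangeExec(
    override val outputPartitioning: Partitioning,
    child: SparkPlan,
    shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) extends ShuffleExchangeLike {
  // ...
}
```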

### Why are the changes needed?

Similar to `repartition(number, col)`, we should respect the user-specified partitioning.
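As a rough illustration (not part of this patch; the `spark` session and the column names are assumed), the two flavors of repartition differ in what AQE is allowed to touch:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

spark.conf.set("spark.sql.adaptive.enabled", "true")
val df = spark.range(0, 1000).select($"id".as("key"), ($"id" % 10).as("value"))

// repartition(col) -> shuffle origin REPARTITION: AQE may still coalesce the
// shuffle partitions, but it must keep the hash partitioning on `key`, so the
// local shuffle reader is not applied.
val byColumn = df.repartition($"key")

// repartition(num, col) -> shuffle origin REPARTITION_WITH_NUM: the partition
// number is user-specified, so AQE changes neither the number nor the partitioning.
val byNumberAndColumn = df.repartition(10, $"key")

byColumn.collect()
byNumberAndColumn.collect()
```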

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

A new test in `AdaptiveQueryExecSuite` (see the last hunk in this diff).
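For a hand-run sanity check outside the suite, one could count local readers in the final adaptive plan. This is a sketch under assumptions: a `spark` session with the `testData` table registered, and this Spark version's `AdaptiveSparkPlanExec`/`CustomShuffleReaderExec` API.

```scala
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, CustomShuffleReaderExec}
import org.apache.spark.sql.functions.col

// Hypothetical helper, not part of the patch: count local shuffle readers in
// the final plan produced by AQE.
def numLocalReaders(plan: SparkPlan): Int = plan match {
  case a: AdaptiveSparkPlanExec => numLocalReaders(a.executedPlan)
  case p => p.collect { case r: CustomShuffleReaderExec if r.isLocalReader => r }.size
}

spark.conf.set("spark.sql.adaptive.enabled", "true")
val df = spark.table("testData").repartition(col("key"))
df.collect() // materialize so AQE finalizes the plan
// With this patch, the shuffle introduced by repartition keeps its hash
// partitioning, so no local reader should be placed on top of it.
assert(numLocalReaders(df.queryExecution.executedPlan) == 0)
```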

Closes #30432 from cloud-fan/aqe.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Wenchen Fan 2020-11-25 02:02:32 +00:00
parent 01321bc0fe
commit d1b4f06179
8 changed files with 86 additions and 48 deletions


@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRe
import org.apache.spark.sql.execution.aggregate.AggUtils
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NUM, ShuffleExchangeExec}
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
@@ -670,7 +670,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
ShuffleExchangeExec(RoundRobinPartitioning(numPartitions),
planLater(child), noUserSpecifiedNumPartition = false) :: Nil
planLater(child), REPARTITION_WITH_NUM) :: Nil
} else {
execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
}
@@ -703,10 +703,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r: logical.Range =>
execution.RangeExec(r) :: Nil
case r: logical.RepartitionByExpression =>
exchange.ShuffleExchangeExec(
r.partitioning,
planLater(r.child),
noUserSpecifiedNumPartition = r.optNumPartitions.isEmpty) :: Nil
val shuffleOrigin = if (r.optNumPartitions.isEmpty) {
REPARTITION
} else {
REPARTITION_WITH_NUM
}
exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), shuffleOrigin) :: Nil
case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil


@@ -18,8 +18,10 @@
package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike}
import org.apache.spark.sql.internal.SQLConf
/**
@@ -47,7 +49,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
val shuffleStages = collectShuffleStages(plan)
// ShuffleExchanges introduced by repartition do not support changing the number of partitions.
// We change the number of partitions in the stage only if all the ShuffleExchanges support it.
if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) {
if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) {
plan
} else {
// `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
@@ -82,4 +84,9 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
}
}
}
private def supportCoalesce(s: ShuffleExchangeLike): Boolean = {
s.outputPartitioning != SinglePartition &&
(s.shuffleOrigin == ENSURE_REQUIREMENTS || s.shuffleOrigin == REPARTITION)
}
}


@@ -18,9 +18,10 @@
package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.internal.SQLConf
@@ -136,9 +137,13 @@ object OptimizeLocalShuffleReader extends Rule[SparkPlan] {
def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
case s: ShuffleQueryStageExec =>
s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
s.mapStats.isDefined && supportLocalReader(s.shuffle)
case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs) =>
s.shuffle.canChangeNumPartitions && s.mapStats.isDefined && partitionSpecs.nonEmpty
s.mapStats.isDefined && partitionSpecs.nonEmpty && supportLocalReader(s.shuffle)
case _ => false
}
private def supportLocalReader(s: ShuffleExchangeLike): Boolean = {
s.outputPartitioning != SinglePartition && s.shuffleOrigin == ENSURE_REQUIREMENTS
}
}


@@ -57,9 +57,9 @@ trait ShuffleExchangeLike extends Exchange {
def numPartitions: Int
/**
* Returns whether the shuffle partition number can be changed.
* The origin of this shuffle operator.
*/
def canChangeNumPartitions: Boolean
def shuffleOrigin: ShuffleOrigin
/**
* The asynchronous job that materializes the shuffle.
@@ -77,18 +77,30 @@ trait ShuffleExchangeLike extends Exchange {
def runtimeStatistics: Statistics
}
// Describes where the shuffle operator comes from.
sealed trait ShuffleOrigin
// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. It
// means that the shuffle operator is used to ensure internal data partitioning requirements and
// Spark is free to optimize it as long as the requirements are still ensured.
case object ENSURE_REQUIREMENTS extends ShuffleOrigin
// Indicates that the shuffle operator was added by the user-specified repartition operator. Spark
// can still optimize it via changing shuffle partition number, as data partitioning won't change.
case object REPARTITION extends ShuffleOrigin
// Indicates that the shuffle operator was added by the user-specified repartition operator with
// a certain partition number. Spark can't optimize it.
case object REPARTITION_WITH_NUM extends ShuffleOrigin
/**
* Performs a shuffle that will result in the desired partitioning.
*/
case class ShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
noUserSpecifiedNumPartition: Boolean = true) extends ShuffleExchangeLike {
// If users specify the num partitions via APIs like `repartition`, we shouldn't change it.
// For `SinglePartition`, it requires exactly one partition and we can't change it either.
override def canChangeNumPartitions: Boolean =
noUserSpecifiedNumPartition && outputPartitioning != SinglePartition
shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
extends ShuffleExchangeLike {
private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)


@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 23
-- Number of queries: 24
-- !query
@@ -67,10 +67,10 @@ Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
+- Exchange SinglePartition, true, [id=#x]
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
+- HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
+- HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), true, [id=#x]
+- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
+- HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>
@@ -116,7 +116,7 @@ Results [2]: [key#x, max#x]
(4) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(5) HashAggregate
Input [2]: [key#x, max#x]
@@ -127,7 +127,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]
(6) Exchange
Input [2]: [key#x, max(val)#x]
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x]
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]
(7) Sort
Input [2]: [key#x, max(val)#x]
@@ -179,7 +179,7 @@ Results [2]: [key#x, max#x]
(4) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(5) HashAggregate
Input [2]: [key#x, max#x]
@@ -254,7 +254,7 @@ Results [2]: [key#x, val#x]
(7) Exchange
Input [2]: [key#x, val#x]
Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(8) HashAggregate
Input [2]: [key#x, val#x]
@@ -576,7 +576,7 @@ Results [2]: [key#x, max#x]
(4) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(5) HashAggregate
Input [2]: [key#x, max#x]
@@ -605,7 +605,7 @@ Results [2]: [key#x, max#x]
(9) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(10) HashAggregate
Input [2]: [key#x, max#x]
@@ -687,7 +687,7 @@ Results [3]: [count#xL, sum#xL, count#xL]
(3) Exchange
Input [3]: [count#xL, sum#xL, count#xL]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(4) HashAggregate
Input [3]: [count#xL, sum#xL, count#xL]
@@ -732,7 +732,7 @@ Results [2]: [key#x, buf#x]
(3) Exchange
Input [2]: [key#x, buf#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(4) ObjectHashAggregate
Input [2]: [key#x, buf#x]
@@ -783,7 +783,7 @@ Results [2]: [key#x, min#x]
(4) Exchange
Input [2]: [key#x, min#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(5) Sort
Input [2]: [key#x, min#x]


@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 23
-- Number of queries: 24
-- !query
@@ -66,10 +66,10 @@ Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
== Physical Plan ==
*HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
+- Exchange SinglePartition, true, [id=#x]
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
+- *HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
+- *HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), true, [id=#x]
+- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
+- *HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+- *ColumnarToRow
+- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>
@@ -119,7 +119,7 @@ Results [2]: [key#x, max#x]
(5) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(6) HashAggregate [codegen id : 2]
Input [2]: [key#x, max#x]
@@ -130,7 +130,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]
(7) Exchange
Input [2]: [key#x, max(val)#x]
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x]
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]
(8) Sort [codegen id : 3]
Input [2]: [key#x, max(val)#x]
@@ -181,7 +181,7 @@ Results [2]: [key#x, max#x]
(5) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(6) HashAggregate [codegen id : 2]
Input [2]: [key#x, max#x]
@@ -259,7 +259,7 @@ Results [2]: [key#x, val#x]
(9) Exchange
Input [2]: [key#x, val#x]
Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(10) HashAggregate [codegen id : 4]
Input [2]: [key#x, val#x]
@@ -452,7 +452,7 @@ Results [1]: [max#x]
(9) Exchange
Input [1]: [max#x]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(10) HashAggregate [codegen id : 2]
Input [1]: [max#x]
@@ -498,7 +498,7 @@ Results [1]: [max#x]
(16) Exchange
Input [1]: [max#x]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(17) HashAggregate [codegen id : 2]
Input [1]: [max#x]
@@ -580,7 +580,7 @@ Results [1]: [max#x]
(9) Exchange
Input [1]: [max#x]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(10) HashAggregate [codegen id : 2]
Input [1]: [max#x]
@@ -626,7 +626,7 @@ Results [2]: [sum#x, count#xL]
(16) Exchange
Input [2]: [sum#x, count#xL]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(17) HashAggregate [codegen id : 2]
Input [2]: [sum#x, count#xL]
@@ -690,7 +690,7 @@ Results [2]: [sum#x, count#xL]
(7) Exchange
Input [2]: [sum#x, count#xL]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(8) HashAggregate [codegen id : 2]
Input [2]: [sum#x, count#xL]
@@ -810,7 +810,7 @@ Results [2]: [key#x, max#x]
(5) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(6) HashAggregate [codegen id : 4]
Input [2]: [key#x, max#x]
@@ -901,7 +901,7 @@ Results [3]: [count#xL, sum#xL, count#xL]
(4) Exchange
Input [3]: [count#xL, sum#xL, count#xL]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(5) HashAggregate [codegen id : 2]
Input [3]: [count#xL, sum#xL, count#xL]
@@ -945,7 +945,7 @@ Results [2]: [key#x, buf#x]
(4) Exchange
Input [2]: [key#x, buf#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(5) ObjectHashAggregate
Input [2]: [key#x, buf#x]
@@ -995,7 +995,7 @@ Results [2]: [key#x, min#x]
(5) Exchange
Input [2]: [key#x, min#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
(6) Sort [codegen id : 2]
Input [2]: [key#x, min#x]


@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
@@ -766,7 +766,9 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike {
override def numMappers: Int = delegate.numMappers
override def numPartitions: Int = delegate.numPartitions
override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
override def shuffleOrigin: ShuffleOrigin = {
delegate.shuffleOrigin
}
override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
delegate.mapOutputStatisticsFuture
override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =


@@ -1307,4 +1307,14 @@ class AdaptiveQueryExecSuite
spark.listenerManager.unregister(listener)
}
}
test("SPARK-33494: Do not use local shuffle reader for repartition") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val df = spark.table("testData").repartition('key)
df.collect()
// local shuffle reader breaks partitioning and shouldn't be used for repartition operation
// which is specified by users.
checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1)
}
}
}