[SPARK-34899][SQL] Use origin plan if we can not coalesce shuffle partition
### What changes were proposed in this pull request?

Add a check of whether `CoalesceShufflePartitions` actually reduces the number of shuffle partitions.

### Why are the changes needed?

`CoalesceShufflePartitions` cannot coalesce anything when the mappers' shuffle partitions are already big enough in total. In that case it is confusing to keep a `CustomShuffleReaderExec` that is marked as `coalesced` but has no effect on the partition number.

### Does this PR introduce _any_ user-facing change?

Probably yes; the resulting query plan can change.

### How was this patch tested?

Added a test.

Closes #31994 from ulysses-you/SPARK-34899.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
parent de66fa63f9
commit 24d39a5ee2
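Before the diffs, a minimal sketch of why size-based coalescing can be a no-op. This is not Spark's `ShufflePartitionsUtil`; the greedy merge and all sizes below are illustrative assumptions. It shows the failure mode the PR targets: when every shuffle partition already exceeds the advisory target size, no neighbors can be merged, so the "coalesced" result has exactly as many partitions as the input.

```scala
// Illustrative sketch only, not Spark's ShufflePartitionsUtil: greedy
// size-based coalescing keeps the partition count unchanged when every
// partition already exceeds the advisory target size.
object CoalesceSketch {
  // Merge each partition into the previous one while the running size stays
  // within the advisory target; otherwise start a new output partition.
  def coalesce(sizes: Seq[Long], advisoryTargetSize: Long): Seq[Long] =
    sizes.foldLeft(Vector.empty[Long]) { (acc, size) =>
      acc.lastOption match {
        case Some(last) if last + size <= advisoryTargetSize =>
          acc.init :+ (last + size) // small neighbors collapse into one
        case _ =>
          acc :+ size // too big to merge: partition passes through as-is
      }
    }

  def main(args: Array[String]): Unit = {
    val target = 64L << 20                   // hypothetical 64MB target
    val big    = Seq(100L << 20, 100L << 20) // both already over the target
    assert(coalesce(big, target).length == big.length) // nothing was merged
    val small  = Seq(10L << 20, 10L << 20, 10L << 20)  // three 10MB partitions
    assert(coalesce(small, target).length == 1)        // merged into one
  }
}
```

Before this PR, the no-op case was still wrapped in a `CustomShuffleReaderExec` labelled `coalesced`; the first hunk below adds the check that returns the original plan instead.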
```diff
@@ -72,14 +72,20 @@ case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffl
         validMetrics.toArray,
         advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
         minNumPartitions = minPartitionNum)
-      // This transformation adds new nodes, so we must use `transformUp` here.
-      val stageIds = shuffleStages.map(_.id).toSet
-      plan.transformUp {
-        // even for shuffle exchange whose input RDD has 0 partition, we should still update its
-        // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
-        // number of output partitions.
-        case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
-          CustomShuffleReaderExec(stage, partitionSpecs)
-      }
+      // We can never extend the shuffle partition number, so if we get the same number here,
+      // that means we can not coalesce shuffle partition. Just return the origin plan.
+      if (partitionSpecs.length == distinctNumPreShufflePartitions.head) {
+        plan
+      } else {
+        // This transformation adds new nodes, so we must use `transformUp` here.
+        val stageIds = shuffleStages.map(_.id).toSet
+        plan.transformUp {
+          // even for shuffle exchange whose input RDD has 0 partition, we should still update its
+          // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
+          // number of output partitions.
+          case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
+            CustomShuffleReaderExec(stage, partitionSpecs)
+        }
+      }
     } else {
       plan
```
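The guard works because coalescing only ever merges partition specs, never splits them, so an unchanged spec count proves the rule achieved nothing. A hedged distillation of that condition follows; `PartitionSpec` below is a simplified stand-in for Spark's partition-spec classes, not the real API.

```scala
// Hedged distillation of the new guard; PartitionSpec is a simplified
// stand-in for Spark's partition-spec classes.
object GuardSketch {
  final case class PartitionSpec(startReducerIndex: Int, endReducerIndex: Int)

  // Coalescing can only shrink the number of specs, so "same count as the
  // pre-shuffle partition number" means nothing was coalesced.
  def shouldUseOriginPlan(
      partitionSpecs: Seq[PartitionSpec],
      numPreShufflePartitions: Int): Boolean =
    partitionSpecs.length == numPreShufflePartitions

  def main(args: Array[String]): Unit = {
    // Two pre-shuffle partitions, two specs emitted: a no-op, keep the plan.
    assert(shouldUseOriginPlan(Seq(PartitionSpec(0, 1), PartitionSpec(1, 2)), 2))
    // Four pre-shuffle partitions collapsed into one spec: real coalescing.
    assert(!shouldUseOriginPlan(Seq(PartitionSpec(0, 4)), 4))
  }
}
```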
```diff
@@ -94,7 +94,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
   }

   test(s"determining the number of reducers: aggregate operator$testNameNote") {
-    val test = { spark: SparkSession =>
+    val test: SparkSession => Unit = { spark: SparkSession =>
       val df =
         spark
           .range(0, 1000, 1, numInputPartitions)
```
```diff
@@ -113,14 +113,13 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
       val shuffleReaders = finalPlan.collect {
         case r @ CoalescedShuffleReader() => r
       }
-      assert(shuffleReaders.length === 1)

       minNumPostShufflePartitions match {
         case Some(numPartitions) =>
-          shuffleReaders.foreach { reader =>
-            assert(reader.outputPartitioning.numPartitions === numPartitions)
-          }
+          assert(shuffleReaders.isEmpty)

         case None =>
+          assert(shuffleReaders.length === 1)
           shuffleReaders.foreach { reader =>
             assert(reader.outputPartitioning.numPartitions === 3)
           }
```
```diff
@@ -131,7 +130,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
   }

   test(s"determining the number of reducers: join operator$testNameNote") {
-    val test = { spark: SparkSession =>
+    val test: SparkSession => Unit = { spark: SparkSession =>
       val df1 =
         spark
           .range(0, 1000, 1, numInputPartitions)
```
```diff
@@ -160,14 +159,13 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
       val shuffleReaders = finalPlan.collect {
         case r @ CoalescedShuffleReader() => r
       }
-      assert(shuffleReaders.length === 2)

       minNumPostShufflePartitions match {
         case Some(numPartitions) =>
-          shuffleReaders.foreach { reader =>
-            assert(reader.outputPartitioning.numPartitions === numPartitions)
-          }
+          assert(shuffleReaders.isEmpty)

         case None =>
+          assert(shuffleReaders.length === 2)
           shuffleReaders.foreach { reader =>
             assert(reader.outputPartitioning.numPartitions === 2)
           }
```
```diff
@@ -212,14 +210,13 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
       val shuffleReaders = finalPlan.collect {
         case r @ CoalescedShuffleReader() => r
       }
-      assert(shuffleReaders.length === 2)

       minNumPostShufflePartitions match {
         case Some(numPartitions) =>
-          shuffleReaders.foreach { reader =>
-            assert(reader.outputPartitioning.numPartitions === numPartitions)
-          }
+          assert(shuffleReaders.isEmpty)

         case None =>
+          assert(shuffleReaders.length === 2)
           shuffleReaders.foreach { reader =>
             assert(reader.outputPartitioning.numPartitions === 2)
           }
```
```diff
@@ -264,14 +261,13 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
       val shuffleReaders = finalPlan.collect {
         case r @ CoalescedShuffleReader() => r
       }
-      assert(shuffleReaders.length === 2)

       minNumPostShufflePartitions match {
         case Some(numPartitions) =>
-          shuffleReaders.foreach { reader =>
-            assert(reader.outputPartitioning.numPartitions === numPartitions)
-          }
+          assert(shuffleReaders.isEmpty)

         case None =>
+          assert(shuffleReaders.length === 2)
           shuffleReaders.foreach { reader =>
             assert(reader.outputPartitioning.numPartitions === 3)
           }
```
```diff
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, QueryExecuti
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, REPARTITION, REPARTITION_WITH_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION, REPARTITION_WITH_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
```
|
@ -39,6 +39,7 @@ import org.apache.spark.sql.functions._
|
|||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
|
||||
import org.apache.spark.sql.test.SharedSparkSession
|
||||
import org.apache.spark.sql.test.SQLTestData.TestData
|
||||
import org.apache.spark.sql.types.{IntegerType, StructType}
|
||||
import org.apache.spark.sql.util.QueryExecutionListener
|
||||
import org.apache.spark.util.Utils
|
||||
|
```diff
@@ -1543,4 +1544,35 @@ class AdaptiveQueryExecSuite
     assert(materializeLogs(0).startsWith("Materialize query stage BroadcastQueryStageExec"))
     assert(materializeLogs(1).startsWith("Materialize query stage ShuffleQueryStageExec"))
   }
+
+  test("SPARK-34899: Use origin plan if we can not coalesce shuffle partition") {
+    def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = {
+      assert(collect(ds.queryExecution.executedPlan) {
+        case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
+      }.size == 1)
+      ds.collect()
+      val plan = ds.queryExecution.executedPlan
+      assert(collect(plan) {
+        case c: CustomShuffleReaderExec => c
+      }.isEmpty)
+      assert(collect(plan) {
+        case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
+      }.size == 1)
+      checkAnswer(ds, testData)
+    }
+
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "2258",
+      SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
+      val df = spark.sparkContext.parallelize(
+        (1 to 100).map(i => TestData(i, i.toString)), 10).toDF()
+
+      // partition size [1420, 1420]
+      checkNoCoalescePartitions(df.repartition(), REPARTITION)
+      // partition size [1140, 1119]
+      checkNoCoalescePartitions(df.sort($"key"), ENSURE_REQUIREMENTS)
+    }
+  }
 }
```
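For readers who want to reproduce the scenario outside the suite, here is a hedged standalone sketch. The string keys are the stable names behind the `SQLConf` constants used in the test; the 2258-byte advisory size comes from the test itself, where the two post-shuffle partitions of the sort (about 1140 and 1119 bytes per the test comments) sum to 2259 bytes and therefore can never be merged.

```scala
import org.apache.spark.sql.SparkSession

// Standalone reproduction sketch (assumed sizes and local master; the config
// keys are the string names behind the SQLConf constants in the test above).
object Spark34899Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
      .config("spark.sql.adaptive.advisoryPartitionSizeInBytes", "2258")
      .config("spark.sql.adaptive.coalescePartitions.minPartitionNum", "1")
      .config("spark.sql.shuffle.partitions", "2")
      .getOrCreate()
    import spark.implicits._

    val df = spark.sparkContext
      .parallelize((1 to 100).map(i => (i, i.toString)), 10)
      .toDF("key", "value")

    val sorted = df.sort($"key")
    sorted.collect()
    // With the fix, the final plan keeps the plain shuffle exchange: there is
    // no CustomShuffleReaderExec, because merging the two ~1.1KB partitions
    // would exceed the 2258-byte advisory size.
    println(sorted.queryExecution.executedPlan)
    spark.stop()
  }
}
```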