[SPARK-36221][SQL] Make sure CustomShuffleReaderExec has at least one partition

### What changes were proposed in this pull request?

* Add a non-empty partition check to `CustomShuffleReaderExec`
* Make sure `OptimizeLocalShuffleReader` never produces an empty partition list

### Why are the changes needed?

Since SPARK-32083, AQE partition coalescing always returns at least one partition, so it makes the code more robust to assert non-emptiness in `CustomShuffleReaderExec` itself and drop the scattered `partitionSpecs.nonEmpty` guards.
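
As an illustration, here is a minimal, runnable sketch of the fail-fast pattern this PR applies; `Spec`, `ReaderNode`, and `ReaderNodeDemo` are hypothetical stand-ins for this example, not Spark's actual classes:

```scala
// Hypothetical stand-ins, showing only the constructor-time invariant.
final case class Spec(startReducerIndex: Int, endReducerIndex: Int)

final case class ReaderNode(partitionSpecs: Seq[Spec]) {
  // Fail fast with a clear message instead of re-checking nonEmpty at every use site.
  assert(partitionSpecs.nonEmpty, "ReaderNode requires at least one partition")

  // Downstream code can now rely on the invariant unconditionally.
  def numPartitions: Int = partitionSpecs.length
}

object ReaderNodeDemo extends App {
  println(ReaderNode(Seq(Spec(0, 4))).numPartitions) // prints 1
  // ReaderNode(Seq.empty)  // would throw java.lang.AssertionError at construction
}
```

Moving the check into the constructor means every construction path, including future callers, is covered by a single assertion.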

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Not needed; the change only adds an internal assertion and removes now-redundant guards, with no behavior change.

Closes #33431 from ulysses-you/non-empty-partition.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit b70c25881c)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit 677104f495 (parent 3bc9346a3a), authored by ulysses-you on 2021-07-20 20:48:35 +08:00 and committed by Wenchen Fan.
2 changed files with 11 additions and 11 deletions.

```diff
@@ -34,11 +34,14 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
  *
  * @param child           It is usually `ShuffleQueryStageExec`, but can be the shuffle exchange
  *                        node during canonicalization.
- * @param partitionSpecs  The partition specs that defines the arrangement.
+ * @param partitionSpecs  The partition specs that defines the arrangement, requires at least one
+ *                        partition.
  */
 case class CustomShuffleReaderExec private(
     child: SparkPlan,
     partitionSpecs: Seq[ShufflePartitionSpec]) extends UnaryExecNode {
+  assert(partitionSpecs.nonEmpty, "CustomShuffleReaderExec requires at least one partition")
+
   // If this reader is to read shuffle files locally, then all partition specs should be
   // `PartialMapperPartitionSpec`.
   if (partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec])) {
@@ -52,8 +55,7 @@ case class CustomShuffleReaderExec private(
   // If it is a local shuffle reader with one mapper per task, then the output partitioning is
   // the same as the plan before shuffle.
   // TODO this check is based on assumptions of callers' behavior but is sufficient for now.
-  if (partitionSpecs.nonEmpty &&
-      partitionSpecs.forall(_.isInstanceOf[PartialMapperPartitionSpec]) &&
+  if (partitionSpecs.forall(_.isInstanceOf[PartialMapperPartitionSpec]) &&
       partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size ==
       partitionSpecs.length) {
     child match {
@@ -111,7 +113,7 @@ case class CustomShuffleReaderExec private(
   }

   @transient private lazy val partitionDataSizes: Option[Seq[Long]] = {
-    if (partitionSpecs.nonEmpty && !isLocalReader && shuffleStage.get.mapStats.isDefined) {
+    if (!isLocalReader && shuffleStage.get.mapStats.isDefined) {
       Some(partitionSpecs.map {
         case p: CoalescedPartitionSpec =>
           assert(p.dataSize.isDefined)
```

```diff
@@ -68,13 +68,11 @@ object OptimizeLocalShuffleReader extends CustomShuffleReaderRule {
       shuffleStage: ShuffleQueryStageExec,
       advisoryParallelism: Option[Int]): Seq[ShufflePartitionSpec] = {
     val numMappers = shuffleStage.shuffle.numMappers
+    // ShuffleQueryStageExec.mapStats.isDefined promise numMappers > 0
+    assert(numMappers > 0)
     val numReducers = shuffleStage.shuffle.numPartitions
     val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
-    val splitPoints = if (numMappers == 0) {
-      Seq.empty
-    } else {
-      equallyDivide(numReducers, math.max(1, expectedParallelism / numMappers))
-    }
+    val splitPoints = equallyDivide(numReducers, math.max(1, expectedParallelism / numMappers))
     (0 until numMappers).flatMap { mapIndex =>
       (splitPoints :+ numReducers).sliding(2).map {
         case Seq(start, end) => PartialMapperPartitionSpec(mapIndex, start, end)
@@ -127,8 +125,8 @@ object OptimizeLocalShuffleReader extends CustomShuffleReaderRule {
   def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
     case s: ShuffleQueryStageExec =>
       s.mapStats.isDefined && supportLocalReader(s.shuffle)
-    case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs) =>
-      s.mapStats.isDefined && partitionSpecs.nonEmpty && supportLocalReader(s.shuffle) &&
+    case CustomShuffleReaderExec(s: ShuffleQueryStageExec, _) =>
+      s.mapStats.isDefined && supportLocalReader(s.shuffle) &&
       s.shuffle.shuffleOrigin == ENSURE_REQUIREMENTS
     case _ => false
   }
```
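
To make the removed guard concrete: with `numMappers` asserted positive, `equallyDivide` is always called with at least one bucket and therefore yields at least one split point, so every mapper contributes at least one `PartialMapperPartitionSpec`. A runnable sketch follows; the `equallyDivide` body here is a stand-in written for this example (Spark's real helper is private and may distribute the remainder differently), and `LocalReaderSketch` is a hypothetical name:

```scala
object LocalReaderSketch extends App {
  final case class PartialMapperPartitionSpec(
      mapIndex: Int, startReducerIndex: Int, endReducerIndex: Int)

  // Divide `numElements` reducers into `numBuckets` roughly equal slices and
  // return each slice's start index (non-empty whenever numBuckets >= 1).
  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Int] = {
    val base = numElements / numBuckets
    val remainder = numElements % numBuckets
    // The first `remainder` buckets get one extra element.
    (0 until numBuckets).map(i => i * base + math.min(i, remainder))
  }

  val numMappers = 2
  val numReducers = 5
  val expectedParallelism = 4
  // After this PR, numMappers is asserted > 0, so no empty-split special case.
  val splitPoints = equallyDivide(numReducers, math.max(1, expectedParallelism / numMappers))
  val specs = (0 until numMappers).flatMap { mapIndex =>
    (splitPoints :+ numReducers).sliding(2).map {
      case Seq(start, end) => PartialMapperPartitionSpec(mapIndex, start, end)
    }
  }
  specs.foreach(println)
  // PartialMapperPartitionSpec(0,0,3), PartialMapperPartitionSpec(0,3,5),
  // PartialMapperPartitionSpec(1,0,3), PartialMapperPartitionSpec(1,3,5)
}
```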