[SPARK-35874][SQL] AQE Shuffle should wait for its subqueries to finish before materializing

### What changes were proposed in this pull request?

Currently, AQE triggers and waits for subqueries in a roundabout way:
1. Submitting a stage calls `QueryStageExec.materialize`.
2. `QueryStageExec.materialize` calls `executeQuery`.
3. `executeQuery` does some preparation work, which reaches `QueryStageExec.doPrepare`.
4. `QueryStageExec.doPrepare` calls `prepare` on the shuffle/broadcast, which triggers all the subqueries in this stage.
5. `executeQuery` then calls `waitForSubqueries`, which does nothing because `QueryStageExec` itself has no subqueries.
6. The shuffle/broadcast job is then submitted without waiting for the subqueries.
7. `ShuffleExchangeExec.mapOutputStatisticsFuture` calls `child.execute`, which calls `executeQuery` and waits for the subqueries in the query tree of `child`.
8. The only missing case: `ShuffleExchangeExec` itself may contain subqueries (in its repartition expression), and AQE doesn't wait for them.
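
The gap described in the steps above can be reproduced with a small toy model. This is only a sketch: `Subquery`, `Node`, `Shuffle`, `Stage`, and `materializeOld` are hypothetical stand-ins and not Spark's real classes, and "waiting" is modelled as flipping a flag instead of blocking on a result.

```scala
object SubqueryWaitSketch {
  // Hypothetical toy model; none of these classes exist in Spark.
  class Subquery(val name: String) {
    var started = false
    var finished = false
  }

  abstract class Node {
    def subqueries: Seq[Subquery]
    def children: Seq[Node]

    // Like prepare(): starts the subqueries of the whole subtree.
    def prepare(): Unit = {
      subqueries.foreach(_.started = true)
      children.foreach(_.prepare())
    }

    // Like waitForSubqueries(): covers only this node's own subqueries.
    def waitForSubqueries(): Unit = subqueries.foreach(_.finished = true)

    def executeQuery[T](body: => T): T = {
      prepare()
      waitForSubqueries()
      body
    }
  }

  class Shuffle(val subqueries: Seq[Subquery]) extends Node {
    val children: Seq[Node] = Nil
    // True only if every subquery the shuffle depends on has finished.
    def submitJob(): Boolean = subqueries.forall(_.finished)
  }

  class Stage(shuffle: Shuffle) extends Node {
    val subqueries: Seq[Subquery] = Nil
    val children: Seq[Node] = Seq(shuffle)
    // Old flow: executeQuery runs on the stage, which owns no subqueries,
    // so the shuffle's subquery is started but never waited for.
    def materializeOld(): Boolean = executeQuery { shuffle.submitJob() }
  }
}
```

In this model, `new Stage(new Shuffle(Seq(new Subquery("max_key")))).materializeOld()` submits the job while the subquery is still unfinished, which corresponds to step 8.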

A simple fix would be to override `waitForSubqueries` in `QueryStageExec` and forward the request to the shuffle/broadcast. This PR proposes a different and probably cleaner way: follow the `execute`/`doExecute` pattern in `SparkPlan` and add a similar API for the AQE version of "execute", which gets a future from the shuffle/broadcast.
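
A minimal sketch of that pattern, assuming a simplified exchange trait: `ToyExchange`, `ToyShuffle`, `submitJob`, and `jobFuture` are hypothetical names used only for illustration, and `executeQuery` here merely flips a flag where Spark would prepare the plan and wait for subqueries.

```scala
import scala.concurrent.Future

object SubmitJobSketch {
  // Hypothetical stand-in for an exchange trait; not Spark code.
  trait ToyExchange {
    @volatile var subqueriesReady = false

    // Stand-in for the preparation step (waiting for this node's subqueries).
    protected def waitForSubqueries(): Unit = {
      subqueriesReady = true
    }

    protected final def executeQuery[T](body: => T): T = {
      waitForSubqueries()
      body
    }

    // Public, final entry point: preparation always runs before the job starts.
    final def submitJob: Future[Int] = executeQuery {
      jobFuture
    }

    // Implementation hook; protected so callers cannot skip the preparation.
    protected def jobFuture: Future[Int]
  }

  class ToyShuffle extends ToyExchange {
    // Lazy, so the job is created only on first use, after preparation.
    protected lazy val jobFuture: Future[Int] = {
      require(subqueriesReady, "job started before its subqueries finished")
      Future.successful(42)
    }
  }
}
```

Making the entry point `final` and the hook `protected` mirrors how `execute` wraps `doExecute`: implementers supply the job, and callers can reach it only through the wrapper that runs the preparation first.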

### Why are the changes needed?

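Bug fix: AQE could submit a shuffle job before the subqueries in the shuffle's own repartition expression had finished.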

### Does this PR introduce _any_ user-facing change?

Yes. A query that failed without the fix now runs.

### How was this patch tested?

new test

Closes #33058 from cloud-fan/aqe.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 2df67a1a1b)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Wenchen Fan 2021-07-09 00:20:50 +08:00
parent f31cf163d9
commit b8d3da16b1
5 changed files with 36 additions and 20 deletions


@ -80,7 +80,7 @@ abstract class QueryStageExec extends LeafExecNode {
* broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
* stage is ready.
*/
final def materialize(): Future[Any] = executeQuery {
final def materialize(): Future[Any] = {
logDebug(s"Materialize query stage ${this.getClass.getSimpleName}: $id")
doMaterialize()
}
@ -119,7 +119,6 @@ abstract class QueryStageExec extends LeafExecNode {
override def executeTail(n: Int): Array[InternalRow] = plan.executeTail(n)
override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator()
protected override def doPrepare(): Unit = plan.prepare()
protected override def doExecute(): RDD[InternalRow] = plan.execute()
override def supportsColumnar: Boolean = plan.supportsColumnar
protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
@ -171,7 +170,9 @@ case class ShuffleQueryStageExec(
throw new IllegalStateException(s"wrong plan for shuffle stage:\n ${plan.treeString}")
}
override def doMaterialize(): Future[Any] = shuffle.mapOutputStatisticsFuture
@transient private lazy val shuffleFuture = shuffle.submitShuffleJob
override def doMaterialize(): Future[Any] = shuffleFuture
override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
val reuse = ShuffleQueryStageExec(
@ -182,13 +183,10 @@ case class ShuffleQueryStageExec(
reuse
}
override def cancel(): Unit = {
shuffle.mapOutputStatisticsFuture match {
case action: FutureAction[MapOutputStatistics]
if !shuffle.mapOutputStatisticsFuture.isCompleted =>
action.cancel()
case _ =>
}
override def cancel(): Unit = shuffleFuture match {
case action: FutureAction[MapOutputStatistics] if !action.isCompleted =>
action.cancel()
case _ =>
}
/**
@ -224,7 +222,7 @@ case class BroadcastQueryStageExec(
}
@transient private lazy val materializeWithTimeout = {
val broadcastFuture = broadcast.completionFuture
val broadcastFuture = broadcast.submitBroadcastJob
val timeout = conf.broadcastTimeout
val promise = Promise[Any]()
val fail = BroadcastQueryStageExec.scheduledExecutor.schedule(new Runnable() {


@ -55,10 +55,15 @@ trait BroadcastExchangeLike extends Exchange {
def relationFuture: Future[broadcast.Broadcast[Any]]
/**
* For registering callbacks on `relationFuture`.
* Note that calling this method may not start the execution of broadcast job.
* The asynchronous job that materializes the broadcast. It's used for registering callbacks on
* `relationFuture`. Note that calling this method may not start the execution of broadcast job.
* It also does the preparations work, such as waiting for the subqueries.
*/
def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
final def submitBroadcastJob: scala.concurrent.Future[broadcast.Broadcast[Any]] = executeQuery {
completionFuture
}
protected def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
/**
* Returns the runtime statistics after broadcast materialization.


@ -61,9 +61,14 @@ trait ShuffleExchangeLike extends Exchange {
def shuffleOrigin: ShuffleOrigin
/**
* The asynchronous job that materializes the shuffle.
* The asynchronous job that materializes the shuffle. It also does the preparations work,
* such as waiting for the subqueries.
*/
def mapOutputStatisticsFuture: Future[MapOutputStatistics]
final def submitShuffleJob: Future[MapOutputStatistics] = executeQuery {
mapOutputStatisticsFuture
}
protected def mapOutputStatisticsFuture: Future[MapOutputStatistics]
/**
* Returns the shuffle RDD with specified partition specs.
@ -123,13 +128,14 @@ case class ShuffleExchangeExec(
override def nodeName: String = "Exchange"
private val serializer: Serializer =
private lazy val serializer: Serializer =
new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
@transient lazy val inputRDD: RDD[InternalRow] = child.execute()
// 'mapOutputStatisticsFuture' is only needed when enable AQE.
@transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
@transient
override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
if (inputRDD.getNumPartitions == 0) {
Future.successful(null)
} else {


@ -829,7 +829,7 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE
delegate.shuffleOrigin
}
override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
delegate.mapOutputStatisticsFuture
delegate.submitShuffleJob
override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
delegate.getShuffleRDD(partitionSpecs)
override def runtimeStatistics: Statistics = delegate.runtimeStatistics
@ -848,7 +848,7 @@ case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends Broa
override def runId: UUID = delegate.runId
override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] =
delegate.relationFuture
override def completionFuture: Future[Broadcast[Any]] = delegate.completionFuture
override def completionFuture: Future[Broadcast[Any]] = delegate.submitBroadcastJob
override def runtimeStatistics: Statistics = delegate.runtimeStatistics
override def child: SparkPlan = delegate.child
override protected def doPrepare(): Unit = delegate.prepare()


@ -1952,6 +1952,13 @@ class AdaptiveQueryExecSuite
}
}
}
test("SPARK-35874: AQE Shuffle should wait for its subqueries to finish before materializing") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val query = "SELECT b FROM testData2 DISTRIBUTE BY (b, (SELECT max(key) FROM testData))"
runAdaptiveAndVerifyResult(query)
}
}
}
/**