diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index c2a9b46bce..f308829f66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -80,7 +80,7 @@ abstract class QueryStageExec extends LeafExecNode {
    * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
    * stage is ready.
    */
-  final def materialize(): Future[Any] = executeQuery {
+  final def materialize(): Future[Any] = {
     logDebug(s"Materialize query stage ${this.getClass.getSimpleName}: $id")
     doMaterialize()
   }
@@ -119,7 +119,6 @@ abstract class QueryStageExec extends LeafExecNode {
   override def executeTail(n: Int): Array[InternalRow] = plan.executeTail(n)
   override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator()
 
-  protected override def doPrepare(): Unit = plan.prepare()
   protected override def doExecute(): RDD[InternalRow] = plan.execute()
   override def supportsColumnar: Boolean = plan.supportsColumnar
   protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
@@ -171,7 +170,9 @@ case class ShuffleQueryStageExec(
     throw new IllegalStateException(s"wrong plan for shuffle stage:\n ${plan.treeString}")
   }
 
-  override def doMaterialize(): Future[Any] = shuffle.mapOutputStatisticsFuture
+  @transient private lazy val shuffleFuture = shuffle.submitShuffleJob
+
+  override def doMaterialize(): Future[Any] = shuffleFuture
 
   override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
     val reuse = ShuffleQueryStageExec(
@@ -182,13 +183,10 @@ case class ShuffleQueryStageExec(
     reuse
   }
 
-  override def cancel(): Unit = {
-    shuffle.mapOutputStatisticsFuture match {
-      case action: FutureAction[MapOutputStatistics]
-          if !shuffle.mapOutputStatisticsFuture.isCompleted =>
-        action.cancel()
-      case _ =>
-    }
+  override def cancel(): Unit = shuffleFuture match {
+    case action: FutureAction[MapOutputStatistics] if !action.isCompleted =>
+      action.cancel()
+    case _ =>
   }
 
   /**
@@ -224,7 +222,7 @@ case class BroadcastQueryStageExec(
   }
 
   @transient private lazy val materializeWithTimeout = {
-    val broadcastFuture = broadcast.completionFuture
+    val broadcastFuture = broadcast.submitBroadcastJob
     val timeout = conf.broadcastTimeout
     val promise = Promise[Any]()
     val fail = BroadcastQueryStageExec.scheduledExecutor.schedule(new Runnable() {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index cc12be6f32..7859785da8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -55,10 +55,15 @@ trait BroadcastExchangeLike extends Exchange {
   def relationFuture: Future[broadcast.Broadcast[Any]]
 
   /**
-   * For registering callbacks on `relationFuture`.
-   * Note that calling this method may not start the execution of broadcast job.
+   * The asynchronous job that materializes the broadcast. It's used for registering callbacks
+   * on `relationFuture`. Note that calling this method may not start the execution of the
+   * broadcast job. It also does the preparation work, such as waiting for the subqueries.
    */
-  def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
+  final def submitBroadcastJob: scala.concurrent.Future[broadcast.Broadcast[Any]] = executeQuery {
+    completionFuture
+  }
+
+  protected def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
 
   /**
    * Returns the runtime statistics after broadcast materialization.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 1632a9e5b4..e8cf7684bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -61,9 +61,14 @@ trait ShuffleExchangeLike extends Exchange {
   def shuffleOrigin: ShuffleOrigin
 
   /**
-   * The asynchronous job that materializes the shuffle.
+   * The asynchronous job that materializes the shuffle. It also does the preparation work,
+   * such as waiting for the subqueries.
    */
-  def mapOutputStatisticsFuture: Future[MapOutputStatistics]
+  final def submitShuffleJob: Future[MapOutputStatistics] = executeQuery {
+    mapOutputStatisticsFuture
+  }
+
+  protected def mapOutputStatisticsFuture: Future[MapOutputStatistics]
 
   /**
    * Returns the shuffle RDD with specified partition specs.
@@ -123,13 +128,14 @@ case class ShuffleExchangeExec(
 
   override def nodeName: String = "Exchange"
 
-  private val serializer: Serializer =
+  private lazy val serializer: Serializer =
     new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
 
   @transient lazy val inputRDD: RDD[InternalRow] = child.execute()
 
   // 'mapOutputStatisticsFuture' is only needed when enable AQE.
-  @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
+  @transient
+  override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
     if (inputRDD.getNumPartitions == 0) {
       Future.successful(null)
     } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 8225204089..b1c3fd5af0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -829,7 +829,7 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE
     delegate.shuffleOrigin
   }
   override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
-    delegate.mapOutputStatisticsFuture
+    delegate.submitShuffleJob
   override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
     delegate.getShuffleRDD(partitionSpecs)
   override def runtimeStatistics: Statistics = delegate.runtimeStatistics
@@ -848,7 +848,7 @@ case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends Broa
   override def runId: UUID = delegate.runId
   override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] =
     delegate.relationFuture
-  override def completionFuture: Future[Broadcast[Any]] = delegate.completionFuture
+  override def completionFuture: Future[Broadcast[Any]] = delegate.submitBroadcastJob
   override def runtimeStatistics: Statistics = delegate.runtimeStatistics
   override def child: SparkPlan = delegate.child
   override protected def doPrepare(): Unit = delegate.prepare()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 8bc67fca20..d811ba7180 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1952,6 +1952,13 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-35874: AQE Shuffle should wait for its subqueries to finish before materializing") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val query = "SELECT b FROM testData2 DISTRIBUTE BY (b, (SELECT max(key) FROM testData))"
+      runAdaptiveAndVerifyResult(query)
+    }
+  }
 }
 
 /**
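
The heart of this change is where the `executeQuery` wrapper runs. Previously `QueryStageExec.materialize()` called `executeQuery`, which prepares a plan and waits for its subqueries, against the leaf query stage; now the new final methods `submitShuffleJob` and `submitBroadcastJob` run it on the exchange node itself, so a subquery buried in the exchange (for example, in the shuffle's partitioning expressions) is guaranteed to finish before the shuffle or broadcast job is submitted. The sketch below shows that template-method shape in isolation; `ExchangeLike`, `prepareSubqueries`, `resultFuture`, and `submitJob` are simplified stand-ins for illustration, not the real Spark APIs.

    import scala.concurrent.{ExecutionContext, Future}

    // A minimal sketch of the pattern this diff introduces: the *final* submit
    // method owns the preparation step, and subclasses can only supply the raw
    // future. All names here are illustrative stand-ins.
    trait ExchangeLike {
      // Stand-in for SparkPlan.executeQuery, which runs prepare() and
      // waitForSubqueries() before evaluating the body.
      protected def executeQuery[T](body: => T): T = {
        prepareSubqueries()
        body
      }

      protected def prepareSubqueries(): Unit

      // Subclasses expose only the protected future...
      protected def resultFuture: Future[Long]

      // ...so every caller goes through the final method and the preparation
      // step can never be skipped.
      final def submitJob: Future[Long] = executeQuery(resultFuture)
    }

    object Demo extends App {
      implicit val ec: ExecutionContext = ExecutionContext.global

      val exchange = new ExchangeLike {
        protected def prepareSubqueries(): Unit = println("subqueries finished")
        // Lazy, so the job is created at most once, mirroring the cached
        // `shuffleFuture` in ShuffleQueryStageExec above.
        protected lazy val resultFuture: Future[Long] = Future(42L)
      }
      exchange.submitJob.foreach(println)
      Thread.sleep(500) // crude wait so the demo prints before the JVM exits
    }

Caching the submitted future on the stage (the new `shuffleFuture` lazy val) is the companion fix: `submitShuffleJob` is a plain `def` that re-enters `executeQuery` on every call, so the stage memoizes the result once and both `doMaterialize()` and `cancel()` observe the same `FutureAction`.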
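
The regression test is small but targeted: `DISTRIBUTE BY (b, (SELECT max(key) FROM testData))` plants a scalar subquery inside the shuffle's partitioning expressions, exactly the spot where materializing the stage before the subquery finished used to fail. A rough spark-shell equivalent, with the suite's `testData`/`testData2` fixtures approximated as temp views (the real fixtures live in SQLTestData), might look like:

    import spark.implicits._

    spark.conf.set("spark.sql.adaptive.enabled", "true")

    // Approximations of the suite's fixtures; shapes match, contents are stand-ins.
    Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
      .toDF("a", "b").createOrReplaceTempView("testData2")
    Seq((1, "1"), (2, "2"), (3, "3"))
      .toDF("key", "value").createOrReplaceTempView("testData")

    // Before SPARK-35874, AQE could submit the shuffle map job while the
    // max(key) subquery was still running.
    spark.sql(
      "SELECT b FROM testData2 DISTRIBUTE BY (b, (SELECT max(key) FROM testData))"
    ).collect()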