diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
index 57a930dfaf..a0fd74088c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -17,14 +17,16 @@ package org.apache.spark.sql
 
 package org.apache.spark.sql
 
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+
 import scala.concurrent.duration._
 import scala.math.abs
 import scala.util.Random
 
 import org.scalatest.concurrent.Eventually
 
-import org.apache.spark.{SparkException, TaskContext}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -152,39 +154,53 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
   }
 
   test("Cancelling stage in a query with Range.") {
-    val listener = new SparkListener {
-      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
-        eventually(timeout(10.seconds), interval(1.millis)) {
-          assert(DataFrameRangeSuite.stageToKill > 0)
-        }
-        sparkContext.cancelStage(DataFrameRangeSuite.stageToKill)
-      }
-    }
+    // Save and restore the value because SparkContext is shared
+    val savedInterruptOnCancel = sparkContext
+      .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL)
 
-    sparkContext.addSparkListener(listener)
-    for (codegen <- Seq(true, false)) {
-      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
-        DataFrameRangeSuite.stageToKill = -1
-        val ex = intercept[SparkException] {
-          spark.range(0, 100000000000L, 1, 1).map { x =>
-            DataFrameRangeSuite.stageToKill = TaskContext.get().stageId()
-            x
-          }.toDF("id").agg(sum("id")).collect()
+    try {
+      sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")
+
+      for (codegen <- Seq(true, false)) {
+        // This countdown latch makes sure that the listener calls cancelStage for all the stages
+        val latch = new CountDownLatch(2)
+
+        val listener = new SparkListener {
+          override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+            sparkContext.cancelStage(taskStart.stageId)
+            latch.countDown()
+          }
         }
-        ex.getCause() match {
-          case null =>
-            assert(ex.getMessage().contains("cancelled"))
-          case cause: SparkException =>
-            assert(cause.getMessage().contains("cancelled"))
-          case cause: Throwable =>
-            fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
+
+        sparkContext.addSparkListener(listener)
+        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+          val ex = intercept[SparkException] {
+            sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x =>
+              x.synchronized {
+                x.wait()
+              }
+              x
+            }.toDF("id").agg(sum("id")).collect()
+          }
+          ex.getCause() match {
+            case null =>
+              assert(ex.getMessage().contains("cancelled"))
+            case cause: SparkException =>
+              assert(cause.getMessage().contains("cancelled"))
+            case cause: Throwable =>
+              fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
+          }
         }
+        latch.await(20, TimeUnit.SECONDS)
+        eventually(timeout(20.seconds)) {
+          assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+        }
+        sparkContext.removeSparkListener(listener)
       }
-      eventually(timeout(20.seconds)) {
-        assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
-      }
+    } finally {
+      sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL,
+        savedInterruptOnCancel)
     }
-    sparkContext.removeSparkListener(listener)
   }
 
   test("SPARK-20430 Initialize Range parameters in a driver side") {
@@ -204,7 +220,3 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
     }
   }
 }
-
-object DataFrameRangeSuite {
-  @volatile var stageToKill = -1
-}
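
Note (reviewer sketch, not part of the patch): the rewritten test cancels each stage from SparkListener.onTaskStart instead of polling a shared @volatile stage id, blocks its tasks in wait() so that only the interrupt triggered by spark.job.interruptOnCancel releases them, and uses a CountDownLatch plus the status tracker to wait until no tasks are left running before the next codegen iteration. Below is a minimal, self-contained sketch of that cancellation pattern outside the suite; the object name, the local[2] master, and the job body are invented for illustration, and the literal "spark.job.interruptOnCancel" is used because SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL is private[spark].

    import java.util.concurrent.{CountDownLatch, TimeUnit}

    import org.apache.spark.SparkException
    import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
    import org.apache.spark.sql.SparkSession

    // Illustrative only: cancel a stage as soon as one of its tasks has started.
    object CancelStageSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("cancel-stage-sketch")
          .getOrCreate()
        val sc = spark.sparkContext
        // Same property the test sets via SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL.
        sc.setLocalProperty("spark.job.interruptOnCancel", "true")

        // Counted down once cancelStage has been called, so the driver can wait for it.
        val latch = new CountDownLatch(1)
        sc.addSparkListener(new SparkListener {
          override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
            sc.cancelStage(taskStart.stageId)
            latch.countDown()
          }
        })

        try {
          // Tasks block until the interrupt from the cancellation wakes them up.
          sc.range(0, 100L, numSlices = 2).mapPartitions { iter =>
            Thread.sleep(Long.MaxValue)
            iter
          }.count()
        } catch {
          case e: SparkException =>
            println(s"Job failed as expected: ${e.getMessage}")
        }

        latch.await(20, TimeUnit.SECONDS)
        spark.stop()
      }
    }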