[SPARK-27460][FOLLOW-UP][TESTS] Fix flaky tests

## What changes were proposed in this pull request?

This patch fixes several flaky tests:

- Adds `testRetry`/`retry` helpers to `SparkFunSuite` and switches known-flaky tests (in `ExecutorAllocationManagerSuite`, `StatusTrackerSuite`, `FsHistoryProviderSuite`, `SparkListenerWithClusterSuite`, and `StreamingQueryManagerSuite`) from `test` to the new `testRetry` helper (see the sketch below).
- Migrates the affected suites from `BeforeAndAfter` to `BeforeAndAfterEach`, since `retry` resets state between attempts by calling `afterEach()`/`beforeEach()`.
- Makes time-sensitive checks more robust: wraps an assertion in `SparkContextInfoSuite` in `eventually`, guards the polling loop in `SQLAppStatusListenerSuite` against an empty executions list, and increases `eventually`/`failAfter`/streaming timeouts in the YARN, streaming SQL, and receiver suites.
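For illustration only, here is a minimal sketch of how `testRetry` is meant to be used; the suite name and test body are hypothetical and not part of this patch:

```scala
import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}

// Hypothetical example suite. It mixes in LocalSparkContext, which is built on
// BeforeAndAfterEach (not BeforeAndAfter), so retry() can tear down and re-create
// per-test state via afterEach()/beforeEach() between attempts.
class ExampleRetrySuite extends SparkFunSuite with LocalSparkContext {

  override def beforeEach(): Unit = {
    super.beforeEach()
    sc = new SparkContext("local", "test")  // fresh SparkContext for every attempt
  }

  // Runs the body; on failure it calls afterEach()/beforeEach() and retries,
  // up to 2 extra attempts by default.
  testRetry("count is computed correctly") {
    assert(sc.parallelize(1 to 10, 2).count() === 10)
  }
}
```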

## How was this patch tested?
N/A

Closes #24434 from gatorsmile/fixFlakyTest.

Lead-authored-by: gatorsmile <gatorsmile@gmail.com>
Co-authored-by: Hyukjin Kwon <gurwls223@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>

@@ -21,7 +21,7 @@ import scala.collection.mutable

 import org.mockito.ArgumentMatchers.{any, eq => meq}
 import org.mockito.Mockito.{mock, never, verify, when}
-import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
+import org.scalatest.PrivateMethodTester

 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.config
@@ -38,20 +38,24 @@ import org.apache.spark.util.ManualClock
  */
 class ExecutorAllocationManagerSuite
   extends SparkFunSuite
-  with LocalSparkContext
-  with BeforeAndAfter {
+  with LocalSparkContext {

   import ExecutorAllocationManager._
   import ExecutorAllocationManagerSuite._

   private val contexts = new mutable.ListBuffer[SparkContext]()

-  before {
+  override def beforeEach(): Unit = {
+    super.beforeEach()
     contexts.clear()
   }

-  after {
-    contexts.foreach(_.stop())
+  override def afterEach(): Unit = {
+    try {
+      contexts.foreach(_.stop())
+    } finally {
+      super.afterEach()
+    }
   }

   private def post(bus: LiveListenerBus, event: SparkListenerEvent): Unit = {
@@ -282,7 +286,7 @@ class ExecutorAllocationManagerSuite
     assert(totalRunningTasks(manager) === 0)
   }

-  test("cancel pending executors when no longer needed") {
+  testRetry("cancel pending executors when no longer needed") {
     sc = createSparkContext(0, 10, 0)
     val manager = sc.executorAllocationManager.get
     post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(2, 5)))

@@ -17,7 +17,10 @@

 package org.apache.spark

+import scala.concurrent.duration._
+
 import org.scalatest.Assertions
+import org.scalatest.concurrent.Eventually._

 import org.apache.spark.storage.StorageLevel
@@ -58,10 +61,12 @@ class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
   test("getRDDStorageInfo only reports on RDDs that actually persist data") {
     sc = new SparkContext("local", "test")
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
-    assert(sc.getRDDStorageInfo.size === 0)
+    assert(sc.getRDDStorageInfo.length === 0)
     rdd.collect()
     sc.listenerBus.waitUntilEmpty(10000)
-    assert(sc.getRDDStorageInfo.size === 1)
+    eventually(timeout(10.seconds), interval(100.milliseconds)) {
+      assert(sc.getRDDStorageInfo.length === 1)
+    }
     assert(sc.getRDDStorageInfo.head.isCached)
     assert(sc.getRDDStorageInfo.head.memSize > 0)
     assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)

@@ -20,8 +20,10 @@ package org.apache.spark
 // scalastyle:off
 import java.io.File

+import scala.annotation.tailrec
+
 import org.apache.log4j.{Appender, Level, Logger}
-import org.scalatest.{BeforeAndAfterAll, FunSuite, Outcome}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, Outcome}

 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.Tests.IS_TESTING
@@ -54,6 +56,7 @@ import org.apache.spark.util.{AccumulatorContext, Utils}
 abstract class SparkFunSuite
   extends FunSuite
   with BeforeAndAfterAll
+  with BeforeAndAfterEach
   with ThreadAudit
   with Logging {
 // scalastyle:on
@@ -89,6 +92,47 @@ abstract class SparkFunSuite
     getTestResourceFile(file).getCanonicalPath
   }

+  /**
+   * Note: this method doesn't support `BeforeAndAfter`. You must use `BeforeAndAfterEach` to
+   * set up and tear down resources.
+   */
+  def testRetry(s: String, n: Int = 2)(body: => Unit): Unit = {
+    test(s) {
+      retry(n) {
+        body
+      }
+    }
+  }
+
+  /**
+   * Note: this method doesn't support `BeforeAndAfter`. You must use `BeforeAndAfterEach` to
+   * set up and tear down resources.
+   */
+  def retry[T](n: Int)(body: => T): T = {
+    if (this.isInstanceOf[BeforeAndAfter]) {
+      throw new UnsupportedOperationException(
+        s"testRetry/retry cannot be used with ${classOf[BeforeAndAfter]}. " +
+          s"Please use ${classOf[BeforeAndAfterEach]} instead.")
+    }
+    retry0(n, n)(body)
+  }
+
+  @tailrec private final def retry0[T](n: Int, n0: Int)(body: => T): T = {
+    try body
+    catch { case e: Throwable =>
+      if (n > 0) {
+        logWarning(e.getMessage, e)
+        logInfo(s"\n\n===== RETRY #${n0 - n + 1} =====\n")
+        // Reset state before re-attempting in order so that tests which use patterns like
+        // LocalSparkContext to clean up state can work correctly when retried.
+        afterEach()
+        beforeEach()
+        retry0(n-1, n0)(body)
+      }
+      else throw e
+    }
+  }
+
   /**
    * Log the suite name and the test name before and after each test.
   *

@@ -27,7 +27,7 @@ import org.apache.spark.JobExecutionStatus._

 class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkContext {

-  test("basic status API usage") {
+  testRetry("basic status API usage") {
     sc = new SparkContext("local", "test", new SparkConf(false))
     val jobFuture = sc.parallelize(1 to 10000, 2).map(identity).groupBy(identity).collectAsync()
     val jobId: Int = eventually(timeout(10.seconds)) {

@@ -34,7 +34,6 @@ import org.apache.hadoop.security.AccessControlException
 import org.json4s.jackson.JsonMethods._
 import org.mockito.ArgumentMatchers.{any, argThat}
 import org.mockito.Mockito.{doThrow, mock, spy, verify, when}
-import org.scalatest.BeforeAndAfter
 import org.scalatest.Matchers
 import org.scalatest.concurrent.Eventually._

@@ -52,16 +51,21 @@ import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo}
 import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils}
 import org.apache.spark.util.logging.DriverLogger

-class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
+class FsHistoryProviderSuite extends SparkFunSuite with Matchers with Logging {

   private var testDir: File = null

-  before {
+  override def beforeEach(): Unit = {
+    super.beforeEach()
     testDir = Utils.createTempDir(namePrefix = s"a b%20c+d")
   }

-  after {
-    Utils.deleteRecursively(testDir)
+  override def afterEach(): Unit = {
+    try {
+      Utils.deleteRecursively(testDir)
+    } finally {
+      super.afterEach()
+    }
   }

   /** Create a fake log file using the new log format used in Spark 1.3+ */
@@ -733,7 +737,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
       provider.inSafeMode = false
       clock.setTime(10000)

-      eventually(timeout(1.second), interval(10.milliseconds)) {
+      eventually(timeout(3.second), interval(10.milliseconds)) {
         provider.getConfig().keys should not contain ("HDFS State")
       }
     } finally {
@@ -741,7 +745,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
     }
   }

-  test("provider reports error after FS leaves safe mode") {
+  testRetry("provider reports error after FS leaves safe mode") {
     testDir.delete()
     val clock = new ManualClock()
     val provider = new SafeModeTestProvider(createTestConf(), clock)
@@ -751,7 +755,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
       provider.inSafeMode = false
       clock.setTime(10000)

-      eventually(timeout(1.second), interval(10.milliseconds)) {
+      eventually(timeout(3.second), interval(10.milliseconds)) {
         verify(errorHandler).uncaughtException(any(), any())
       }
     } finally {

@@ -19,25 +19,23 @@ package org.apache.spark.scheduler

 import scala.collection.mutable

-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-
 import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TestUtils}
 import org.apache.spark.scheduler.cluster.ExecutorInfo

 /**
  * Unit tests for SparkListener that require a local cluster.
  */
-class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext
-  with BeforeAndAfter with BeforeAndAfterAll {
+class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext {

   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000

-  before {
+  override def beforeEach(): Unit = {
+    super.beforeEach()
     sc = new SparkContext("local-cluster[2,1,1024]", "SparkListenerSuite")
   }

-  test("SparkListener sends executor added message") {
+  testRetry("SparkListener sends executor added message") {
     val listener = new SaveExecutorInfo
     sc.addSparkListener(listener)

@@ -167,7 +167,7 @@ abstract class BaseYarnClusterSuite

     val handle = launcher.startApplication()
     try {
-      eventually(timeout(2.minutes), interval(1.second)) {
+      eventually(timeout(3.minutes), interval(1.second)) {
         assert(handle.getState().isFinal())
       }
     } finally {

@@ -205,7 +205,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
       .startApplication()

     try {
-      eventually(timeout(30.seconds), interval(100.milliseconds)) {
+      eventually(timeout(3.minutes), interval(100.milliseconds)) {
         handle.getState() should be (SparkAppHandle.State.RUNNING)
       }

@@ -213,7 +213,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
       handle.getAppId() should startWith ("application_")
       handle.stop()

-      eventually(timeout(30.seconds), interval(100.milliseconds)) {
+      eventually(timeout(3.minutes), interval(100.milliseconds)) {
         handle.getState() should be (SparkAppHandle.State.KILLED)
       }
     } finally {

@@ -502,7 +502,8 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with
     }

     // Wait for listener to finish computing the metrics for the execution.
-    while (statusStore.executionsList().last.metricValues == null) {
+    while (statusStore.executionsList().isEmpty ||
+        statusStore.executionsList().last.metricValues == null) {
       Thread.sleep(100)
     }

@@ -195,7 +195,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest {

   import testImplicits._

-  override val streamingTimeout = 20.seconds
+  override val streamingTimeout = 80.seconds

   /** Use `format` and `path` to create FileStreamSource via DataFrameReader */
   private def createFileStreamSource(

@@ -89,7 +89,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
   protected val defaultUseV2Sink = false

   /** How long to wait for an active stream to catch up when checking a result. */
-  val streamingTimeout = 10.seconds
+  val streamingTimeout = 60.seconds

   /** A trait for actions that can be performed while testing a streaming DataFrame. */
   trait StreamAction

@@ -23,7 +23,6 @@ import scala.concurrent.Future
 import scala.util.Random
 import scala.util.control.NonFatal

-import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.time.Span
 import org.scalatest.time.SpanSugar._
@@ -35,21 +34,26 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.streaming.util.BlockingSource
 import org.apache.spark.util.Utils

-class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter {
+class StreamingQueryManagerSuite extends StreamTest {

   import AwaitTerminationTester._
   import testImplicits._

   override val streamingTimeout = 20.seconds

-  before {
+  override def beforeEach(): Unit = {
+    super.beforeEach()
     assert(spark.streams.active.isEmpty)
     spark.streams.resetTerminated()
   }

-  after {
-    assert(spark.streams.active.isEmpty)
-    spark.streams.resetTerminated()
+  override def afterEach(): Unit = {
+    try {
+      assert(spark.streams.active.isEmpty)
+      spark.streams.resetTerminated()
+    } finally {
+      super.afterEach()
+    }
   }

   testQuietly("listing") {
@@ -83,7 +87,7 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter {
     }
   }

-  testQuietly("awaitAnyTermination without timeout and resetTerminated") {
+  testRetry("awaitAnyTermination without timeout and resetTerminated") {
     val datasets = Seq.fill(5)(makeDataset._2)
     withQueriesOn(datasets: _*) { queries =>
       require(queries.size === datasets.size)

@@ -122,7 +122,7 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
     }

     // Verify that stopping actually stops the thread
-    failAfter(100.milliseconds) {
+    failAfter(1.second) {
       receiver.stop("test")
       assert(receiver.isStopped)
       assert(!receiver.otherThread.isAlive)