diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index a8973d1b60..2a7a55cbb9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -72,27 +72,31 @@ trait TestPrematureExit { mainObject.printStream = printStream @volatile var exitedCleanly = false + val original = mainObject.exitFn mainObject.exitFn = (_) => exitedCleanly = true - - @volatile var exception: Exception = null - val thread = new Thread { - override def run() = try { - mainObject.main(input) - } catch { - // Capture the exception to check whether the exception contains searchString or not - case e: Exception => exception = e + try { + @volatile var exception: Exception = null + val thread = new Thread { + override def run() = try { + mainObject.main(input) + } catch { + // Capture the exception to check whether the exception contains searchString or not + case e: Exception => exception = e + } } - } - thread.start() - thread.join() - if (exitedCleanly) { - val joined = printStream.lineBuffer.mkString("\n") - assert(joined.contains(searchString)) - } else { - assert(exception != null) - if (!exception.getMessage.contains(searchString)) { - throw exception + thread.start() + thread.join() + if (exitedCleanly) { + val joined = printStream.lineBuffer.mkString("\n") + assert(joined.contains(searchString)) + } else { + assert(exception != null) + if (!exception.getMessage.contains(searchString)) { + throw exception + } } + } finally { + mainObject.exitFn = original } } }