diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
index 1559f7a9f7..162b19d7f0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
@@ -42,6 +42,7 @@ object MasterFailureTest extends Logging {
 
   @volatile var killed = false
   @volatile var killCount = 0
+  @volatile var setupCalled = false
 
   def main(args: Array[String]) {
     if (args.size < 2) {
@@ -131,38 +132,6 @@ object MasterFailureTest extends Logging {
     // Just making sure that the expected output does not have duplicates
     assert(expectedOutput.distinct.toSet == expectedOutput.toSet)
 
-    // Setup the stream computation with the given operation
-    val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation)
-
-    // Start generating files in the a different thread
-    val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds)
-    fileGeneratingThread.start()
-
-    // Run the streams and repeatedly kill it until the last expected output
-    // has been generated, or until it has run for twice the expected time
-    val lastExpectedOutput = expectedOutput.last
-    val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2
-    val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun)
-
-    // Delete directories
-    fileGeneratingThread.join()
-    val fs = checkpointDir.getFileSystem(new Configuration())
-    fs.delete(checkpointDir, true)
-    fs.delete(testDir, true)
-    logInfo("Finished test after " + killCount + " failures")
-    mergedOutput
-  }
-
-  /**
-   * Sets up the stream computation with the given operation, directory (local or HDFS),
-   * and batch duration. Returns the streaming context and the directory to which
-   * files should be written for testing.
-   */
-  private def setupStreams[T: ClassTag](
-      directory: String,
-      batchDuration: Duration,
-      operation: DStream[String] => DStream[T]
-    ): (StreamingContext, Path, Path) = {
     // Reset all state
     reset()
 
@@ -175,16 +144,56 @@ object MasterFailureTest extends Logging {
     fs.mkdirs(checkpointDir)
     fs.mkdirs(testDir)
 
+    // Setup the stream computation with the given operation
+    val ssc = StreamingContext.getOrCreate(checkpointDir.toString, () => {
+      setupStreams(batchDuration, operation, checkpointDir, testDir)
+    })
+
+    // Check if setupStream was called to create StreamingContext
+    // (and not created from checkpoint file)
+    assert(setupCalled, "Setup was not called in the first call to StreamingContext.getOrCreate")
+
+    // Start generating files in the a different thread
+    val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds)
+    fileGeneratingThread.start()
+
+    // Run the streams and repeatedly kill it until the last expected output
+    // has been generated, or until it has run for twice the expected time
+    val lastExpectedOutput = expectedOutput.last
+    val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2
+    val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun)
+
+    fileGeneratingThread.join()
+    fs.delete(checkpointDir, true)
+    fs.delete(testDir, true)
+    logInfo("Finished test after " + killCount + " failures")
+    mergedOutput
+  }
+
+  /**
+   * Sets up the stream computation with the given operation, directory (local or HDFS),
+   * and batch duration. Returns the streaming context and the directory to which
+   * files should be written for testing.
+   */
+  private def setupStreams[T: ClassTag](
+      batchDuration: Duration,
+      operation: DStream[String] => DStream[T],
+      checkpointDir: Path,
+      testDir: Path
+    ): StreamingContext = {
+    // Mark that setup was called
+    setupCalled = true
+
     // Setup the streaming computation with the given operation
     System.clearProperty("spark.driver.port")
     System.clearProperty("spark.hostPort")
-    var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map())
+    val ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map())
     ssc.checkpoint(checkpointDir.toString)
     val inputStream = ssc.textFileStream(testDir.toString)
     val operatedStream = operation(inputStream)
     val outputStream = new TestOutputStream(operatedStream)
     ssc.registerOutputStream(outputStream)
-    (ssc, checkpointDir, testDir)
+    ssc
   }
 
@@ -204,7 +213,7 @@ object MasterFailureTest extends Logging {
     var isTimedOut = false
     val mergedOutput = new ArrayBuffer[T]()
     val checkpointDir = ssc.checkpointDir
-    var batchDuration = ssc.graph.batchDuration
+    val batchDuration = ssc.graph.batchDuration
 
     while(!isLastOutputGenerated && !isTimedOut) {
       // Get the output buffer
@@ -261,7 +270,10 @@ object MasterFailureTest extends Logging {
         )
         Thread.sleep(sleepTime)
         // Recreate the streaming context from checkpoint
-        ssc = new StreamingContext(checkpointDir)
+        ssc = StreamingContext.getOrCreate(checkpointDir, () => {
+          throw new Exception("Trying to create new context when it " +
+            "should be reading from checkpoint file")
+        })
       }
     }
     mergedOutput
@@ -297,6 +309,7 @@ object MasterFailureTest extends Logging {
 
   private def reset() {
     killed = false
     killCount = 0
+    setupCalled = false
   }
 }
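
The patch above swaps the direct `new StreamingContext(checkpointDir)` recovery path for `StreamingContext.getOrCreate`, which rebuilds a fresh context only when no usable checkpoint exists and otherwise restores the previous job from the checkpoint files. The sketch below is a minimal illustration of that pattern outside the test harness, using the same era's API; the paths, app name, and the trivial count() job are assumptions made for the example and are not part of the patch:

import org.apache.spark.streaming.{Seconds, StreamingContext}

object GetOrCreateSketch {
  // Hypothetical locations, chosen only for this example.
  val checkpointPath = "/tmp/getOrCreate-example/checkpoint"
  val inputPath = "/tmp/getOrCreate-example/input"

  // Invoked only when no valid checkpoint is found at checkpointPath.
  def createContext(): StreamingContext = {
    val ssc = new StreamingContext("local[4]", "GetOrCreateSketch", Seconds(1))
    ssc.checkpoint(checkpointPath)
    // A minimal job: count the lines arriving in each batch and print the count.
    ssc.textFileStream(inputPath).count().print()
    ssc
  }

  def main(args: Array[String]) {
    // Recover the context (and its job) from the checkpoint if one exists,
    // otherwise fall back to createContext() -- the same branch the test
    // asserts on via the setupCalled flag.
    val ssc = StreamingContext.getOrCreate(checkpointPath, createContext _)
    ssc.start()
    ssc.awaitTermination()
  }
}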