[SPARK-7341] [STREAMING] [TESTS] Fix the flaky test: org.apache.spark.stre...

...aming.InputStreamsSuite.socket input stream

Remove non-deterministic "Thread.sleep" and use deterministic strategies to fix the flaky failure: https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-Maven-pre-YARN/hadoop.version=1.0.4,label=centos/2127/testReport/junit/org.apache.spark.streaming/InputStreamsSuite/socket_input_stream/

Author: zsxwing <zsxwing@gmail.com>

Closes #5891 from zsxwing/SPARK-7341 and squashes the following commits:

611157a [zsxwing] Add wait methods to BatchCounter and use BatchCounter in InputStreamsSuite
014b58f [zsxwing] Use withXXX to clean up the resources
c9bf746 [zsxwing] Move 'waitForStart' into the 'start' method and fix the code style
9d0de6d [zsxwing] [SPARK-7341][Streaming][Tests] Fix the flaky test: org.apache.spark.streaming.InputStreamsSuite.socket input stream
This commit is contained in:
zsxwing 2015-05-05 02:15:39 -07:00 committed by Tathagata Das
parent 8436f7e98e
commit 4d29867ede
2 changed files with 137 additions and 60 deletions

View file

@ -18,9 +18,9 @@
package org.apache.spark.streaming
import java.io.{File, BufferedWriter, OutputStreamWriter}
import java.net.{SocketException, ServerSocket}
import java.net.{Socket, SocketException, ServerSocket}
import java.nio.charset.Charset
import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
import java.util.concurrent.{CountDownLatch, Executors, TimeUnit, ArrayBlockingQueue}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue}
@ -36,6 +36,7 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener}
import org.apache.spark.util.{ManualClock, Utils}
import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
import org.apache.spark.streaming.receiver.Receiver
@ -43,51 +44,57 @@ import org.apache.spark.streaming.receiver.Receiver
class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
test("socket input stream") {
// Start the server
val testServer = new TestServer()
testServer.start()
withTestServer(new TestServer()) { testServer =>
// Start the server
testServer.start()
// Set up the streaming context and input streams
val ssc = new StreamingContext(conf, batchDuration)
val networkStream = ssc.socketTextStream(
"localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
val outputStream = new TestOutputStream(networkStream, outputBuffer)
def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x)
outputStream.register()
ssc.start()
// Set up the streaming context and input streams
withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
val input = Seq(1, 2, 3, 4, 5)
// Use "batchCount" to make sure we check the result after all batches finish
val batchCounter = new BatchCounter(ssc)
val networkStream = ssc.socketTextStream(
"localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
val outputStream = new TestOutputStream(networkStream, outputBuffer)
outputStream.register()
ssc.start()
// Feed data to the server to send to the network receiver
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val input = Seq(1, 2, 3, 4, 5)
val expectedOutput = input.map(_.toString)
Thread.sleep(1000)
for (i <- 0 until input.size) {
testServer.send(input(i).toString + "\n")
Thread.sleep(500)
clock.advance(batchDuration.milliseconds)
}
Thread.sleep(1000)
logInfo("Stopping server")
testServer.stop()
logInfo("Stopping context")
ssc.stop()
// Feed data to the server to send to the network receiver
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val expectedOutput = input.map(_.toString)
for (i <- 0 until input.size) {
testServer.send(input(i).toString + "\n")
Thread.sleep(500)
clock.advance(batchDuration.milliseconds)
}
// Make sure we finish all batches before "stop"
if (!batchCounter.waitUntilBatchesCompleted(input.size, 30000)) {
fail("Timeout: cannot finish all batches in 30 seconds")
}
logInfo("Stopping server")
testServer.stop()
logInfo("Stopping context")
ssc.stop()
// Verify whether data received was as expected
logInfo("--------------------------------")
logInfo("output.size = " + outputBuffer.size)
logInfo("output")
outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
logInfo("expected output.size = " + expectedOutput.size)
logInfo("expected output")
expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
logInfo("--------------------------------")
// Verify whether data received was as expected
logInfo("--------------------------------")
logInfo("output.size = " + outputBuffer.size)
logInfo("output")
outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
logInfo("expected output.size = " + expectedOutput.size)
logInfo("expected output")
expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
logInfo("--------------------------------")
// Verify whether all the elements received are as expected
// (whether the elements were received one in each interval is not verified)
assert(output.size === expectedOutput.size)
for (i <- 0 until output.size) {
assert(output(i) === expectedOutput(i))
// Verify whether all the elements received are as expected
// (whether the elements were received one in each interval is not verified)
val output: ArrayBuffer[String] = outputBuffer.flatMap(x => x)
assert(output.size === expectedOutput.size)
for (i <- 0 until output.size) {
assert(output(i) === expectedOutput(i))
}
}
}
}
@ -368,31 +375,45 @@ class TestServer(portToBind: Int = 0) extends Logging {
val serverSocket = new ServerSocket(portToBind)
private val startLatch = new CountDownLatch(1)
val servingThread = new Thread() {
override def run() {
try {
while(true) {
logInfo("Accepting connections on port " + port)
val clientSocket = serverSocket.accept()
logInfo("New connection")
try {
clientSocket.setTcpNoDelay(true)
val outputStream = new BufferedWriter(
new OutputStreamWriter(clientSocket.getOutputStream))
if (startLatch.getCount == 1) {
// The first connection is a test connection to implement "waitForStart", so skip it
// and send a signal
if (!clientSocket.isClosed) {
clientSocket.close()
}
startLatch.countDown()
} else {
// Real connections
logInfo("New connection")
try {
clientSocket.setTcpNoDelay(true)
val outputStream = new BufferedWriter(
new OutputStreamWriter(clientSocket.getOutputStream))
while(clientSocket.isConnected) {
val msg = queue.poll(100, TimeUnit.MILLISECONDS)
if (msg != null) {
outputStream.write(msg)
outputStream.flush()
logInfo("Message '" + msg + "' sent")
while (clientSocket.isConnected) {
val msg = queue.poll(100, TimeUnit.MILLISECONDS)
if (msg != null) {
outputStream.write(msg)
outputStream.flush()
logInfo("Message '" + msg + "' sent")
}
}
} catch {
case e: SocketException => logError("TestServer error", e)
} finally {
logInfo("Connection closed")
if (!clientSocket.isClosed) {
clientSocket.close()
}
}
} catch {
case e: SocketException => logError("TestServer error", e)
} finally {
logInfo("Connection closed")
if (!clientSocket.isClosed) clientSocket.close()
}
}
} catch {
@ -404,7 +425,29 @@ class TestServer(portToBind: Int = 0) extends Logging {
}
}
def start() { servingThread.start() }
def start(): Unit = {
servingThread.start()
if (!waitForStart(10000)) {
stop()
throw new AssertionError("Timeout: TestServer cannot start in 10 seconds")
}
}
/**
* Wait until the server starts. Return true if the server starts in "millis" milliseconds.
* Otherwise, return false to indicate it's timeout.
*/
private def waitForStart(millis: Long): Boolean = {
// We will create a test connection to the server so that we can make sure it has started.
val socket = new Socket("localhost", port)
try {
startLatch.await(millis, TimeUnit.MILLISECONDS)
} finally {
if (!socket.isClosed) {
socket.close()
}
}
}
def send(msg: String) { queue.put(msg) }

View file

@ -146,6 +146,40 @@ class BatchCounter(ssc: StreamingContext) {
def getNumStartedBatches: Int = this.synchronized {
numStartedBatches
}
/**
* Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if
* `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's
* timeout.
*
* @param expectedNumCompletedBatches the `expectedNumCompletedBatches` batches to wait
* @param timeout the maximum time to wait in milliseconds.
*/
def waitUntilBatchesCompleted(expectedNumCompletedBatches: Int, timeout: Long): Boolean =
waitUntilConditionBecomeTrue(numCompletedBatches >= expectedNumCompletedBatches, timeout)
/**
* Wait until `expectedNumStartedBatches` batches are completed, or timeout. Return true if
* `expectedNumStartedBatches` batches are completed. Otherwise, return false to indicate it's
* timeout.
*
* @param expectedNumStartedBatches the `expectedNumStartedBatches` batches to wait
* @param timeout the maximum time to wait in milliseconds.
*/
def waitUntilBatchesStarted(expectedNumStartedBatches: Int, timeout: Long): Boolean =
waitUntilConditionBecomeTrue(numStartedBatches >= expectedNumStartedBatches, timeout)
private def waitUntilConditionBecomeTrue(condition: => Boolean, timeout: Long): Boolean = {
synchronized {
var now = System.currentTimeMillis()
val timeoutTick = now + timeout
while (!condition && timeoutTick > now) {
wait(timeoutTick - now)
now = System.currentTimeMillis()
}
condition
}
}
}
/**