[SPARK-5681] [STREAMING] Move 'stopReceivers' to the event loop to resolve the race condition

This is an alternative way to fix `SPARK-5681`. It minimizes the changes.

Closes #4467

Author: zsxwing <zsxwing@gmail.com>
Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #6294 from zsxwing/pr4467 and squashes the following commits:

709ac1f [zsxwing] Fix the comment
e103e8a [zsxwing] Move ReceiverTracker.stop into ReceiverTracker.stop
f637142 [zsxwing] Address minor code style comments
a178d37 [zsxwing] Move 'stopReceivers' to the event looop to resolve the race condition
51fb07e [zsxwing] Fix the code style
3cb19a3 [zsxwing] Merge branch 'master' into pr4467
b4c29e7 [zsxwing] Stop receiver only if we start it
c41ee94 [zsxwing] Make stopReceivers private
7c73c1f [zsxwing] Use trackerStateLock to protect trackerState
a8120c0 [zsxwing] Merge branch 'master' into pr4467
7b1d9af [zsxwing] "case Throwable" => "case NonFatal"
15ed4a1 [zsxwing] Register before starting the receiver
fff63f9 [zsxwing] Use a lock to eliminate the race condition when stopping receivers and registering receivers happen at the same time.
e0ef72a [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into tracker_status_timeout
19b76d9 [Liang-Chi Hsieh] Remove timeout.
34c18dc [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into tracker_status_timeout
c419677 [Liang-Chi Hsieh] Fix style.
9e1a760 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into tracker_status_timeout
355f9ce [Liang-Chi Hsieh] Separate register and start events for receivers.
3d568e8 [Liang-Chi Hsieh] Let receivers get registered first before going started.
ae0d9fd [Liang-Chi Hsieh] Merge branch 'master' into tracker_status_timeout
77983f3 [Liang-Chi Hsieh] Add tracker status and stop to receive messages when stopping tracker.
This commit is contained in:
zsxwing 2015-07-17 14:00:31 -07:00 committed by Tathagata Das
parent 074085d678
commit ad0954f6de
5 changed files with 138 additions and 62 deletions

View file

@ -22,6 +22,7 @@ import java.util.concurrent.CountDownLatch
import scala.collection.mutable.ArrayBuffer
import scala.concurrent._
import scala.util.control.NonFatal
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.storage.StreamBlockId
@ -36,7 +37,7 @@ private[streaming] abstract class ReceiverSupervisor(
conf: SparkConf
) extends Logging {
/** Enumeration to identify current state of the StreamingContext */
/** Enumeration to identify current state of the Receiver */
object ReceiverState extends Enumeration {
type CheckpointState = Value
val Initialized, Started, Stopped = Value
@ -97,8 +98,8 @@ private[streaming] abstract class ReceiverSupervisor(
/** Called when supervisor is stopped */
protected def onStop(message: String, error: Option[Throwable]) { }
/** Called when receiver is started */
protected def onReceiverStart() { }
/** Called when receiver is started. Return true if the driver accepts us */
protected def onReceiverStart(): Boolean
/** Called when receiver is stopped */
protected def onReceiverStop(message: String, error: Option[Throwable]) { }
@ -121,13 +122,17 @@ private[streaming] abstract class ReceiverSupervisor(
/** Start receiver */
def startReceiver(): Unit = synchronized {
try {
logInfo("Starting receiver")
receiver.onStart()
logInfo("Called receiver onStart")
onReceiverStart()
receiverState = Started
if (onReceiverStart()) {
logInfo("Starting receiver")
receiverState = Started
receiver.onStart()
logInfo("Called receiver onStart")
} else {
// The driver refused us
stop("Registered unsuccessfully because Driver refused to start receiver " + streamId, None)
}
} catch {
case t: Throwable =>
case NonFatal(t) =>
stop("Error starting receiver " + streamId, Some(t))
}
}
@ -136,12 +141,19 @@ private[streaming] abstract class ReceiverSupervisor(
def stopReceiver(message: String, error: Option[Throwable]): Unit = synchronized {
try {
logInfo("Stopping receiver with message: " + message + ": " + error.getOrElse(""))
receiverState = Stopped
receiver.onStop()
logInfo("Called receiver onStop")
onReceiverStop(message, error)
receiverState match {
case Initialized =>
logWarning("Skip stopping receiver because it has not yet stared")
case Started =>
receiverState = Stopped
receiver.onStop()
logInfo("Called receiver onStop")
onReceiverStop(message, error)
case Stopped =>
logWarning("Receiver has been stopped")
}
} catch {
case t: Throwable =>
case NonFatal(t) =>
logError("Error stopping receiver " + streamId + t.getStackTraceString)
}
}
@ -167,7 +179,7 @@ private[streaming] abstract class ReceiverSupervisor(
}(futureExecutionContext)
}
/** Check if receiver has been marked for stopping */
/** Check if receiver has been marked for starting */
def isReceiverStarted(): Boolean = {
logDebug("state = " + receiverState)
receiverState == Started

View file

@ -162,7 +162,7 @@ private[streaming] class ReceiverSupervisorImpl(
env.rpcEnv.stop(endpoint)
}
override protected def onReceiverStart() {
override protected def onReceiverStart(): Boolean = {
val msg = RegisterReceiver(
streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint)
trackerEndpoint.askWithRetry[Boolean](msg)

View file

@ -20,7 +20,6 @@ package org.apache.spark.streaming.scheduler
import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap}
import scala.language.existentials
import scala.math.max
import org.apache.spark.rdd._
import org.apache.spark.streaming.util.WriteAheadLogUtils
import org.apache.spark.{Logging, SparkEnv, SparkException}
@ -47,6 +46,8 @@ private[streaming] case class ReportError(streamId: Int, message: String, error:
private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String)
extends ReceiverTrackerMessage
private[streaming] case object StopAllReceivers extends ReceiverTrackerMessage
/**
* This class manages the execution of the receivers of ReceiverInputDStreams. Instance of
* this class must be created after all input streams have been added and StreamingContext.start()
@ -71,13 +72,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
)
private val listenerBus = ssc.scheduler.listenerBus
/** Enumeration to identify current state of the ReceiverTracker */
object TrackerState extends Enumeration {
type TrackerState = Value
val Initialized, Started, Stopping, Stopped = Value
}
import TrackerState._
/** State of the tracker. Protected by "trackerStateLock" */
@volatile private var trackerState = Initialized
// endpoint is created when generator starts.
// This not being null means the tracker has been started and not stopped
private var endpoint: RpcEndpointRef = null
/** Start the endpoint and receiver execution thread. */
def start(): Unit = synchronized {
if (endpoint != null) {
if (isTrackerStarted) {
throw new SparkException("ReceiverTracker already started")
}
@ -86,20 +97,46 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
"ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv))
if (!skipReceiverLaunch) receiverExecutor.start()
logInfo("ReceiverTracker started")
trackerState = Started
}
}
/** Stop the receiver execution thread. */
def stop(graceful: Boolean): Unit = synchronized {
if (!receiverInputStreams.isEmpty && endpoint != null) {
if (isTrackerStarted) {
// First, stop the receivers
if (!skipReceiverLaunch) receiverExecutor.stop(graceful)
trackerState = Stopping
if (!skipReceiverLaunch) {
// Send the stop signal to all the receivers
endpoint.askWithRetry[Boolean](StopAllReceivers)
// Wait for the Spark job that runs the receivers to be over
// That is, for the receivers to quit gracefully.
receiverExecutor.awaitTermination(10000)
if (graceful) {
val pollTime = 100
logInfo("Waiting for receiver job to terminate gracefully")
while (receiverInfo.nonEmpty || receiverExecutor.running) {
Thread.sleep(pollTime)
}
logInfo("Waited for receiver job to terminate gracefully")
}
// Check if all the receivers have been deregistered or not
if (receiverInfo.nonEmpty) {
logWarning("Not all of the receivers have deregistered, " + receiverInfo)
} else {
logInfo("All of the receivers have deregistered successfully")
}
}
// Finally, stop the endpoint
ssc.env.rpcEnv.stop(endpoint)
endpoint = null
receivedBlockTracker.stop()
logInfo("ReceiverTracker stopped")
trackerState = Stopped
}
}
@ -145,14 +182,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
host: String,
receiverEndpoint: RpcEndpointRef,
senderAddress: RpcAddress
) {
): Boolean = {
if (!receiverInputStreamIds.contains(streamId)) {
throw new SparkException("Register received for unexpected id " + streamId)
}
receiverInfo(streamId) = ReceiverInfo(
streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
logInfo("Registered receiver for stream " + streamId + " from " + senderAddress)
if (isTrackerStopping || isTrackerStopped) {
false
} else {
// "stopReceivers" won't happen at the same time because both "registerReceiver" and are
// called in the event loop. So here we can assume "stopReceivers" has not yet been called. If
// "stopReceivers" is called later, it should be able to see this receiver.
receiverInfo(streamId) = ReceiverInfo(
streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
logInfo("Registered receiver for stream " + streamId + " from " + senderAddress)
true
}
}
/** Deregister a receiver */
@ -220,20 +266,33 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RegisterReceiver(streamId, typ, host, receiverEndpoint) =>
registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
context.reply(true)
val successful =
registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
context.reply(successful)
case AddBlock(receivedBlockInfo) =>
context.reply(addBlock(receivedBlockInfo))
case DeregisterReceiver(streamId, message, error) =>
deregisterReceiver(streamId, message, error)
context.reply(true)
case StopAllReceivers =>
assert(isTrackerStopping || isTrackerStopped)
stopReceivers()
context.reply(true)
}
/** Send stop signal to the receivers. */
private def stopReceivers() {
// Signal the receivers to stop
receiverInfo.values.flatMap { info => Option(info.endpoint)}
.foreach { _.send(StopReceiver) }
logInfo("Sent stop signal to all " + receiverInfo.size + " receivers")
}
}
/** This thread class runs all the receivers on the cluster. */
class ReceiverLauncher {
@transient val env = ssc.env
@volatile @transient private var running = false
@volatile @transient var running = false
@transient val thread = new Thread() {
override def run() {
try {
@ -249,31 +308,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
thread.start()
}
def stop(graceful: Boolean) {
// Send the stop signal to all the receivers
stopReceivers()
// Wait for the Spark job that runs the receivers to be over
// That is, for the receivers to quit gracefully.
thread.join(10000)
if (graceful) {
val pollTime = 100
logInfo("Waiting for receiver job to terminate gracefully")
while (receiverInfo.nonEmpty || running) {
Thread.sleep(pollTime)
}
logInfo("Waited for receiver job to terminate gracefully")
}
// Check if all the receivers have been deregistered or not
if (receiverInfo.nonEmpty) {
logWarning("Not all of the receivers have deregistered, " + receiverInfo)
} else {
logInfo("All of the receivers have deregistered successfully")
}
}
/**
* Get the list of executors excluding driver
*/
@ -358,17 +392,30 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
// Distribute the receivers and start them
logInfo("Starting " + receivers.length + " receivers")
running = true
ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver))
running = false
logInfo("All of the receivers have been terminated")
try {
ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver))
logInfo("All of the receivers have been terminated")
} finally {
running = false
}
}
/** Stops the receivers. */
private def stopReceivers() {
// Signal the receivers to stop
receiverInfo.values.flatMap { info => Option(info.endpoint)}
.foreach { _.send(StopReceiver) }
logInfo("Sent stop signal to all " + receiverInfo.size + " receivers")
/**
* Wait until the Spark job that runs the receivers is terminated, or return when
* `milliseconds` elapses
*/
def awaitTermination(milliseconds: Long): Unit = {
thread.join(milliseconds)
}
}
/** Check if tracker has been marked for starting */
private def isTrackerStarted(): Boolean = trackerState == Started
/** Check if tracker has been marked for stopping */
private def isTrackerStopping(): Boolean = trackerState == Stopping
/** Check if tracker has been marked for stopped */
private def isTrackerStopped(): Boolean = trackerState == Stopped
}

View file

@ -346,6 +346,8 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
def reportError(message: String, throwable: Throwable) {
errors += throwable
}
override protected def onReceiverStart(): Boolean = true
}
/**

View file

@ -285,6 +285,21 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
}
}
test("stop gracefully even if a receiver misses StopReceiver") {
// This is not a deterministic unit. But if this unit test is flaky, then there is definitely
// something wrong. See SPARK-5681
val conf = new SparkConf().setMaster(master).setAppName(appName)
sc = new SparkContext(conf)
ssc = new StreamingContext(sc, Milliseconds(100))
val input = ssc.receiverStream(new TestReceiver)
input.foreachRDD(_ => {})
ssc.start()
// Call `ssc.stop` at once so that it's possible that the receiver will miss "StopReceiver"
failAfter(30000 millis) {
ssc.stop(stopSparkContext = true, stopGracefully = true)
}
}
test("stop slow receiver gracefully") {
val conf = new SparkConf().setMaster(master).setAppName(appName)
conf.set("spark.streaming.gracefulStopTimeout", "20000s")