Graceful shutdown after a single transmission in the swarm is over.

There might still be a problem with the Tracker shutdown. It must be done explicitly by SparkContext.
This commit is contained in:
Mosharaf Chowdhury 2010-11-04 22:09:14 -07:00
parent 10fc66b1c4
commit 878d157ce3
2 changed files with 160 additions and 77 deletions

View file

@ -1 +1 @@
-Dspark.broadcast.MasterHostAddress=127.0.0.1 -Dspark.broadcast.MasterTrackerPort=11111 -Dspark.broadcast.BlockSize=256 -Dspark.broadcast.MaxRetryCount=2 -Dspark.broadcast.ServerSocketTimout=50000 -Dspark.broadcast.MaxChatTime=500
-Dspark.broadcast.MasterHostAddress=127.0.0.1 -Dspark.broadcast.MasterTrackerPort=11111 -Dspark.broadcast.BlockSize=256 -Dspark.broadcast.MaxRetryCount=2 -Dspark.broadcast.TrackerSocketTimeout=50000 -Dspark.broadcast.ServerSocketTimeout=10000 -Dspark.broadcast.MaxChatTime=500

View file

@ -8,7 +8,7 @@ import com.google.common.collect.MapMaker
import java.util.concurrent.{Executors, ExecutorService, ThreadPoolExecutor}
import scala.collection.mutable.{ListBuffer, Map}
import scala.collection.mutable.{ListBuffer, Map, Set}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
@ -66,6 +66,7 @@ extends BroadcastRecipe with Logging {
@transient var guidePort = -1
@transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
if (!local) {
@ -77,7 +78,8 @@ extends BroadcastRecipe with Logging {
// TODO: Turned OFF for now
// val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid))
// out.writeObject (value_)
// out.close
// out.close
// TODO: Fix this at some point
hasCopyInHDFS = true
// Create a variableInfo object and store it in valueInfos
@ -203,6 +205,8 @@ extends BroadcastRecipe with Logging {
listenPort = -1
listOfSources = ListBuffer[SourceInfo] ()
stopBroadcast = false
}
private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
@ -270,6 +274,8 @@ extends BroadcastRecipe with Logging {
var localSourceInfo = SourceInfo (hostAddress, listenPort, totalBlocks,
totalBytes)
localSourceInfo.hasBlocks = hasBlocks
hasBlocksBitVector.synchronized {
localSourceInfo.hasBlocksBitVector = hasBlocksBitVector
@ -299,40 +305,48 @@ extends BroadcastRecipe with Logging {
}
}
class TalkToGuide (gInfo: SourceInfo)
class TalkToGuide (gInfo: SourceInfo)
extends Thread with Logging {
override def run = {
// Connect to Guide and send this worker's information
// Keep exchanging information until all blocks have been received
while (hasBlocks < totalBlocks) {
talkOnce
Thread.sleep ( BroadcastBT.ranGen.nextInt (
BroadcastBT.MaxKnockInterval - BroadcastBT.MinKnockInterval) +
BroadcastBT.MinKnockInterval)
}
// Talk one more time to let the Guide know of reception completion
talkOnce
}
// Connect to Guide and send this worker's information
private def talkOnce = {
var clientSocketToGuide: Socket = null
var oosGuide: ObjectOutputStream = null
var oisGuide: ObjectInputStream = null
while (hasBlocks < totalBlocks) {
clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort)
oosGuide = new ObjectOutputStream (clientSocketToGuide.getOutputStream)
oosGuide.flush
oisGuide = new ObjectInputStream (clientSocketToGuide.getInputStream)
clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort)
oosGuide = new ObjectOutputStream (clientSocketToGuide.getOutputStream)
oosGuide.flush
oisGuide = new ObjectInputStream (clientSocketToGuide.getInputStream)
// Send local information
oosGuide.writeObject(getLocalSourceInfo)
oosGuide.flush
// Send local information
oosGuide.writeObject(getLocalSourceInfo)
oosGuide.flush
// Receive source information from Guide
var suitableSources =
oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
logInfo("Received suitableSources from Master " + suitableSources)
addToListOfSources (suitableSources)
oisGuide.close
oosGuide.close
clientSocketToGuide.close
Thread.sleep ( BroadcastBT.ranGen.nextInt (
BroadcastBT.MaxKnockInterval - BroadcastBT.MinKnockInterval) +
BroadcastBT.MinKnockInterval)
}
}
// Receive source information from Guide
var suitableSources =
oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
logInfo("Received suitableSources from Master " + suitableSources)
addToListOfSources (suitableSources)
oisGuide.close
oosGuide.close
clientSocketToGuide.close
}
}
def getGuideInfo (variableUUID: UUID): SourceInfo = {
@ -427,7 +441,7 @@ extends BroadcastRecipe with Logging {
// TODO: Must fix this. This might never break if broadcast fails.
// We should be able to break and send false. Also need to kill threads
while (hasBlocks != totalBlocks) {
while (hasBlocks < totalBlocks) {
Thread.sleep(1234)
}
@ -448,7 +462,7 @@ extends BroadcastRecipe with Logging {
Math.min (listOfSources.size, BroadcastBT.MaxTxPeers) -
threadPool.getActiveCount
while(numThreadsToCreate > 0 && hasBlocks < totalBlocks) {
while (hasBlocks < totalBlocks && numThreadsToCreate > 0) {
var peerToTalkTo = pickPeerToTalkTo
if (peerToTalkTo != null) {
threadPool.execute (new TalkToPeer (peerToTalkTo))
@ -464,8 +478,11 @@ extends BroadcastRecipe with Logging {
}
// Sleep for a while before starting some more threads
// TODO: What's up with this?
Thread.sleep (500)
}
// Shutdown the thread pool
threadPool.shutdown
}
// TODO: Right now picking the one that has the most blocks this peer wants
@ -510,7 +527,6 @@ extends BroadcastRecipe with Logging {
private var oisSource: ObjectInputStream = null
override def run = {
// Setup the timeout mechanism
var timeOutTask = new TimerTask {
override def run = {
@ -616,6 +632,9 @@ extends BroadcastRecipe with Logging {
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo] ()
override def run = {
// TODO: Cached threadpool has 60s keep alive timer
var threadPool = Executors.newCachedThreadPool
@ -629,18 +648,24 @@ extends BroadcastRecipe with Logging {
guidePortLock.notifyAll
}
var keepAccepting = true
try {
// Don't stop until there is a copy in HDFS
while (keepAccepting || !hasCopyInHDFS) {
while (!stopBroadcast || !hasCopyInHDFS) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout (BroadcastBT.ServerSocketTimout)
serverSocket.setSoTimeout (BroadcastBT.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo ("GuideMultipleRequests Timeout. Stopping listening..." + hasCopyInHDFS)
keepAccepting = false
case e: Exception => {
logInfo ("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far is done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
}
}
}
if (clientSocket != null) {
@ -655,9 +680,46 @@ extends BroadcastRecipe with Logging {
}
}
}
// Shutdown the thread pool
threadPool.shutdown
logInfo ("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
BroadcastBT.unregisterValue (uuid)
} finally {
serverSocket.close
if (serverSocket != null) {
logInfo ("GuideMultipleRequests now stopping...")
serverSocket.close
}
}
}
// Notify every entry currently in listOfSources that the broadcast is over.
//
// For each source this opens a connection to its (hostAddress, listenPort),
// reads and discards the first object the source sends, and then writes a
// SourceInfo whose listenPort is SourceInfo.StopBroadcast.
//
// Delivery is best-effort: a failure while talking to one source is logged
// and does not prevent the remaining sources from being notified, nor does it
// propagate to the caller. Sockets and streams are always closed.
private def sendStopBroadcastNotifications: Unit = {
  // Take a snapshot while holding the lock so that the blocking network I/O
  // below is not performed with listOfSources locked
  val sourcesToNotify = listOfSources.synchronized { listOfSources.toList }

  sourcesToNotify.foreach { sourceInfo =>
    var guideSocketToSource: Socket = null
    var gosSource: ObjectOutputStream = null
    var gisSource: ObjectInputStream = null

    try {
      // Connect to the source
      guideSocketToSource =
        new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
      gosSource =
        new ObjectOutputStream (guideSocketToSource.getOutputStream)
      gosSource.flush
      gisSource =
        new ObjectInputStream (guideSocketToSource.getInputStream)

      // Throw away whatever comes in
      gisSource.readObject.asInstanceOf[SourceInfo]

      // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast
      gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast,
        SourceInfo.UnusedParam, SourceInfo.UnusedParam))
      gosSource.flush
    } catch {
      case e: Exception => {
        // Keep going: the other sources still need to be told to stop
        logInfo ("Failed to send stopBroadcast notification to " +
          sourceInfo + ": " + e)
      }
    } finally {
      // Close in the same order as before; a failing close must not stop
      // the remaining resources from being released
      if (gisSource != null) {
        try { gisSource.close } catch { case e: Exception => { } }
      }
      if (gosSource != null) {
        try { gosSource.close } catch { case e: Exception => { } }
      }
      if (guideSocketToSource != null) {
        try { guideSocketToSource.close } catch { case e: Exception => { } }
      }
    }
  }
}
@ -670,6 +732,7 @@ extends BroadcastRecipe with Logging {
private var sourceInfo: SourceInfo = null
private var selectedSources: ListBuffer[SourceInfo] = null
// Used to select a rolling window of peers from listOfSources
private var rollOverIndex = 0
override def run = {
@ -689,7 +752,7 @@ extends BroadcastRecipe with Logging {
} catch {
case e: Exception => {
// Assuming exception caused by receiver failure: remove
if (listOfSources != null) {
if (listOfSources != null) {
listOfSources.synchronized {
listOfSources = listOfSources - sourceInfo
}
@ -703,11 +766,20 @@ extends BroadcastRecipe with Logging {
}
// TODO: Randomly select some sources to send back.
// Right now just rolls over the listOfSources to send back
// Right now just rolls over the listOfSources to send back
// BroadcastBT.MaxPeersInGuideResponse number of possible sources
private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = {
var curIndex = rollOverIndex
private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = {
var selectedSources = ListBuffer[SourceInfo] ()
// If skipSourceInfo.hasBlocksBitVector has all bits set to 'true'
// then add skipSourceInfo to setOfCompletedSources. Return blank.
if (skipSourceInfo.hasBlocks == totalBlocks) {
setOfCompletedSources += skipSourceInfo
return selectedSources
}
var curIndex = rollOverIndex
listOfSources.synchronized {
do {
if (listOfSources(curIndex) != skipSourceInfo) {
@ -726,12 +798,10 @@ extends BroadcastRecipe with Logging {
class ServeMultipleRequests
extends Thread with Logging {
override def run = {
// TODO: Look into ExecutorService shutdown and shutdownNow methods
// TODO: Not sure if this will be able to fix the number of outgoing links
// We should have a timeout mechanism on the receiver side
var threadPool =
Executors.newFixedThreadPool(
BroadcastBT.MaxRxPeers).asInstanceOf[ThreadPoolExecutor]
Executors.newFixedThreadPool(BroadcastBT.MaxRxPeers)
var serverSocket = new ServerSocket (0)
listenPort = serverSocket.getLocalPort
@ -742,17 +812,15 @@ extends BroadcastRecipe with Logging {
listenPortLock.notifyAll
}
var keepAccepting = true
try {
while (keepAccepting) {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout (BroadcastBT.ServerSocketTimout)
serverSocket.setSoTimeout (BroadcastBT.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo ("ServeMultipleRequests Timeout. Stopping listening...")
keepAccepting = false
logInfo ("ServeMultipleRequests Timeout.")
}
}
if (clientSocket != null) {
@ -768,10 +836,13 @@ extends BroadcastRecipe with Logging {
}
}
} finally {
if (serverSocket != null) {
if (serverSocket != null) {
logInfo ("ServeMultipleRequests now stopping...")
serverSocket.close
}
}
}
// Shutdown the thread pool
threadPool.shutdown
}
class ServeSingleRequest (val clientSocket: Socket)
@ -791,18 +862,22 @@ extends BroadcastRecipe with Logging {
oos.flush
// Receive latest SourceInfo from the receiver
var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
logInfo("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
addToListOfSources (rxSourceInfo)
// TODO: NOT the most efficient way to do time-based break;
// but using timer can cause a break in the middle :-S
if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
// Carry on
addToListOfSources (rxSourceInfo)
}
val startTime = System.currentTimeMillis
var curTime = startTime
var keepSending = true
var blocksToSend = BroadcastBT.MaxChatBlocks
while (keepSending && blocksToSend > 0 &&
while (!stopBroadcast && keepSending && blocksToSend > 0 &&
(curTime - startTime) < BroadcastBT.MaxChatTime) {
val sentBlock = pickAndSendBlock (rxSourceInfo.hasBlocksBitVector)
if (sentBlock < 0) {
@ -930,6 +1005,7 @@ case class SourceInfo (val hostAddress: String, val listenPort: Int,
var currentLeechers = 0
var receptionFailed = false
var hasBlocks = 0
var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
}
@ -938,6 +1014,7 @@ object SourceInfo {
val TxNotStartedRetry = -1
val TxOverGoToHDFS = 0
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
}
@ -1021,7 +1098,9 @@ extends Logging {
private var MasterTrackerPort_ : Int = 11111
private var BlockSize_ : Int = 512 * 1024
private var MaxRetryCount_ : Int = 2
private var ServerSocketTimout_ : Int = 50000
private var TrackerSocketTimeout_ : Int = 50000
private var ServerSocketTimeout_ : Int = 10000
private var trackMV: TrackMultipleValues = null
@ -1055,8 +1134,11 @@ extends Logging {
System.getProperty ("spark.broadcast.BlockSize", "512").toInt * 1024
MaxRetryCount_ =
System.getProperty ("spark.broadcast.MaxRetryCount", "2").toInt
ServerSocketTimout_ =
System.getProperty ("spark.broadcast.ServerSocketTimout", "50000").toInt
TrackerSocketTimeout_ =
System.getProperty ("spark.broadcast.TrackerSocketTimeout", "50000").toInt
ServerSocketTimeout_ =
System.getProperty ("spark.broadcast.ServerSocketTimeout", "10000").toInt
MinKnockInterval_ =
System.getProperty ("spark.broadcast.MinKnockInterval", "500").toInt
@ -1094,7 +1176,9 @@ extends Logging {
def MasterTrackerPort = MasterTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
def ServerSocketTimout = ServerSocketTimout_
def TrackerSocketTimeout = TrackerSocketTimeout_
def ServerSocketTimeout = ServerSocketTimeout_
def isMaster = isMaster_
@ -1122,41 +1206,38 @@ extends Logging {
valueToGuideMap.synchronized {
valueToGuideMap (uuid) = SourceInfo ("", SourceInfo.TxOverGoToHDFS,
SourceInfo.UnusedParam, SourceInfo.UnusedParam)
logInfo ("Value unregistered from the Tracker " + valueToGuideMap)
logInfo ("Value unregistered from the Tracker " + valueToGuideMap)
}
}
// def startMultiTracker
// def stopMultiTracker
class TrackMultipleValues
extends Thread with Logging {
var keepAccepting = true
var stopTracker = false
override def run = {
var threadPool = Executors.newCachedThreadPool
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket (BroadcastBT.MasterTrackerPort)
logInfo ("TrackMultipleValues" + serverSocket)
logInfo ("TrackMultipleValues" + serverSocket)
try {
while (keepAccepting) {
while (!stopTracker) {
var clientSocket: Socket = null
try {
// TODO:
serverSocket.setSoTimeout (ServerSocketTimout)
serverSocket.setSoTimeout (TrackerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
case e: Exception => {
logInfo ("TrackMultipleValues Timeout. Stopping listening...")
// TODO: Tracking should be explicitly stopped by the SparkContext
keepAccepting = false
stopTracker = true
}
}
if (clientSocket != null) {
try {
try {
threadPool.execute (new Thread {
override def run = {
val oos = new ObjectOutputStream (clientSocket.getOutputStream)
@ -1164,12 +1245,12 @@ extends Logging {
val ois = new ObjectInputStream (clientSocket.getInputStream)
try {
val uuid = ois.readObject.asInstanceOf[UUID]
var gInfo =
var gInfo =
if (valueToGuideMap.contains (uuid)) {
valueToGuideMap (uuid)
} else SourceInfo ("", SourceInfo.TxNotStartedRetry,
} else SourceInfo ("", SourceInfo.TxNotStartedRetry,
SourceInfo.UnusedParam, SourceInfo.UnusedParam)
logInfo ("TrackMultipleValues:Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
logInfo ("TrackMultipleValues:Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
oos.writeObject (gInfo)
} catch {
case e: Exception => { }
@ -1190,7 +1271,9 @@ extends Logging {
}
} finally {
serverSocket.close
}
}
// Shutdown the thread pool
threadPool.shutdown
}
}
}