CustomBlockedInMemoryShuffle is an in- memroy implementation of CustomBlockedLFS
This commit is contained in:
parent
a064835808
commit
8dc44bfa96
|
@ -1 +1 @@
|
|||
-Dspark.shuffle.class=spark.CustomBlockedLocalFileShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.maxTxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=2 -Dspark.parallelLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxKnockInterval=2000 -Dspark.parallelInMemoryShuffle.maxRxConnections=2 -Dspark.parallelInMemoryShuffle.maxTxConnections=2 -Dspark.parallelInMemoryShuffle.minKnockInterval=50 -Dspark.parallelInMemoryShuffle.maxKnockInterval=2000
|
||||
-Dspark.shuffle.class=spark.CustomBlockedInMemoryShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.maxTxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.blockedInMemoryShuffle.maxRxConnections=2 -Dspark.blockedInMemoryShuffle.maxTxConnections=2 -Dspark.blockedInMemoryShuffle.minKnockInterval=50 -Dspark.blockedInMemoryShuffle.maxKnockInterval=2000 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=2 -Dspark.parallelLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxKnockInterval=2000 -Dspark.parallelInMemoryShuffle.maxRxConnections=2 -Dspark.parallelInMemoryShuffle.maxTxConnections=2 -Dspark.parallelInMemoryShuffle.minKnockInterval=50 -Dspark.parallelInMemoryShuffle.maxKnockInterval=2000
|
||||
|
|
657
src/scala/spark/CustomBlockedInMemoryShuffle.scala
Normal file
657
src/scala/spark/CustomBlockedInMemoryShuffle.scala
Normal file
|
@ -0,0 +1,657 @@
|
|||
package spark
|
||||
|
||||
import java.io._
|
||||
import java.net._
|
||||
import java.util.{BitSet, Random, Timer, TimerTask, UUID}
|
||||
import java.util.concurrent.atomic.AtomicLong
|
||||
import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
|
||||
|
||||
import scala.collection.mutable.{ArrayBuffer, HashMap}
|
||||
|
||||
/**
|
||||
* TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATON (FOR NOW).
|
||||
*
|
||||
* An implementation of shuffle using local memory served through custom server
|
||||
* where receivers create simultaneous connections to multiple servers by
|
||||
* setting the 'spark.blockedLocalFileShuffle.maxRxConnections' config option.
|
||||
*
|
||||
* By controlling the 'spark.blockedLocalFileShuffle.blockSize' config option
|
||||
* one can also control the largest block size to divide each map output into.
|
||||
* Essentially, instead of creating one large output file for each reducer, maps
|
||||
* create multiple smaller files to enable finer level of engagement.
|
||||
*
|
||||
* 'spark.parallelLocalFileShuffle.maxTxConnections' enforces server-side cap.
|
||||
* Ideally maxTxConnections >= maxRxConnections * numReducersPerMachine
|
||||
*
|
||||
* TODO: Add support for compression when spark.compress is set to true.
|
||||
*/
|
||||
@serializable
|
||||
class CustomBlockedInMemoryShuffle[K, V, C]
|
||||
extends Shuffle[K, V, C] with Logging {
|
||||
@transient var totalSplits = 0
|
||||
@transient var hasSplits = 0
|
||||
|
||||
@transient var totalBlocksInSplit: Array[Int] = null
|
||||
@transient var hasBlocksInSplit: Array[Int] = null
|
||||
|
||||
@transient var hasSplitsBitVector: BitSet = null
|
||||
@transient var splitsInRequestBitVector: BitSet = null
|
||||
|
||||
@transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
|
||||
@transient var combiners: HashMap[K,C] = null
|
||||
|
||||
override def compute(input: RDD[(K, V)],
|
||||
numOutputSplits: Int,
|
||||
createCombiner: V => C,
|
||||
mergeValue: (C, V) => C,
|
||||
mergeCombiners: (C, C) => C)
|
||||
: RDD[(K, C)] =
|
||||
{
|
||||
val sc = input.sparkContext
|
||||
val shuffleId = CustomBlockedInMemoryShuffle.newShuffleId()
|
||||
logInfo("Shuffle ID: " + shuffleId)
|
||||
|
||||
val splitRdd = new NumberedSplitRDD(input)
|
||||
val numInputSplits = splitRdd.splits.size
|
||||
|
||||
// Run a parallel map and collect to write the intermediate data files,
|
||||
// returning a list of inputSplitId -> serverUri pairs
|
||||
val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
|
||||
val myIndex = pair._1
|
||||
val myIterator = pair._2
|
||||
val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
|
||||
for ((k, v) <- myIterator) {
|
||||
var bucketId = k.hashCode % numOutputSplits
|
||||
if (bucketId < 0) { // Fix bucket ID if hash code was negative
|
||||
bucketId += numOutputSplits
|
||||
}
|
||||
val bucket = buckets(bucketId)
|
||||
bucket(k) = bucket.get(k) match {
|
||||
case Some(c) => mergeValue(c, v)
|
||||
case None => createCombiner(v)
|
||||
}
|
||||
}
|
||||
|
||||
for (i <- 0 until numOutputSplits) {
|
||||
var blockNum = 0
|
||||
var isDirty = false
|
||||
|
||||
var splitName = ""
|
||||
var baos: ByteArrayOutputStream = null
|
||||
var oos: ObjectOutputStream = null
|
||||
|
||||
var writeStartTime: Long = 0
|
||||
|
||||
buckets(i).foreach(pair => {
|
||||
// Open a new file if necessary
|
||||
if (!isDirty) {
|
||||
splitName = CustomBlockedInMemoryShuffle.getSplitName(shuffleId,
|
||||
myIndex, i, blockNum)
|
||||
|
||||
baos = new ByteArrayOutputStream
|
||||
oos = new ObjectOutputStream(baos)
|
||||
|
||||
writeStartTime = System.currentTimeMillis
|
||||
logInfo("BEGIN WRITE: " + splitName)
|
||||
}
|
||||
|
||||
oos.writeObject(pair)
|
||||
isDirty = true
|
||||
|
||||
// Close the old file if has crossed the blockSize limit
|
||||
if (baos.size > CustomBlockedInMemoryShuffle.BlockSize) {
|
||||
CustomBlockedInMemoryShuffle.splitsCache(splitName) =
|
||||
baos.toByteArray
|
||||
|
||||
logInfo("END WRITE: " + splitName)
|
||||
val writeTime = System.currentTimeMillis - writeStartTime
|
||||
logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
|
||||
|
||||
blockNum = blockNum + 1
|
||||
isDirty = false
|
||||
oos.close()
|
||||
}
|
||||
})
|
||||
|
||||
if (isDirty) {
|
||||
CustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
|
||||
|
||||
logInfo("END WRITE: " + splitName)
|
||||
val writeTime = System.currentTimeMillis - writeStartTime
|
||||
logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
|
||||
|
||||
blockNum = blockNum + 1
|
||||
oos.close()
|
||||
}
|
||||
|
||||
// Store BLOCKNUM info
|
||||
splitName = CustomBlockedInMemoryShuffle.getBlockNumOutputName(
|
||||
shuffleId, myIndex, i)
|
||||
baos = new ByteArrayOutputStream
|
||||
oos = new ObjectOutputStream(baos)
|
||||
oos.writeObject(blockNum)
|
||||
CustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
|
||||
|
||||
// Close streams
|
||||
oos.close()
|
||||
}
|
||||
|
||||
(myIndex, CustomBlockedInMemoryShuffle.serverAddress,
|
||||
CustomBlockedInMemoryShuffle.serverPort)
|
||||
}).collect()
|
||||
|
||||
val splitsByUri = new ArrayBuffer[(String, Int, Int)]
|
||||
for ((inputId, serverAddress, serverPort) <- outputLocs) {
|
||||
splitsByUri += ((serverAddress, serverPort, inputId))
|
||||
}
|
||||
|
||||
// TODO: Could broadcast outputLocs
|
||||
|
||||
// Return an RDD that does each of the merges for a given partition
|
||||
val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
|
||||
return indexes.flatMap((myId: Int) => {
|
||||
totalSplits = outputLocs.size
|
||||
hasSplits = 0
|
||||
|
||||
totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
|
||||
hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
|
||||
|
||||
hasSplitsBitVector = new BitSet(totalSplits)
|
||||
splitsInRequestBitVector = new BitSet(totalSplits)
|
||||
|
||||
receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
|
||||
combiners = new HashMap[K, C]
|
||||
|
||||
// Start consumer
|
||||
var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
|
||||
shuffleConsumer.setDaemon(true)
|
||||
shuffleConsumer.start()
|
||||
logInfo("ShuffleConsumer started...")
|
||||
|
||||
var threadPool = CustomBlockedInMemoryShuffle.newDaemonFixedThreadPool(
|
||||
CustomBlockedInMemoryShuffle.MaxRxConnections)
|
||||
|
||||
while (hasSplits < totalSplits) {
|
||||
var numThreadsToCreate =
|
||||
Math.min(totalSplits, CustomBlockedInMemoryShuffle.MaxRxConnections) -
|
||||
threadPool.getActiveCount
|
||||
|
||||
while (hasSplits < totalSplits && numThreadsToCreate > 0) {
|
||||
// Select a random split to pull
|
||||
val splitIndex = selectRandomSplit
|
||||
|
||||
if (splitIndex != -1) {
|
||||
val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
|
||||
|
||||
threadPool.execute(new ShuffleClient(serverAddress, serverPort,
|
||||
shuffleId.toInt, inputId, myId, splitIndex))
|
||||
|
||||
// splitIndex is in transit. Will be unset in the ShuffleClient
|
||||
splitsInRequestBitVector.synchronized {
|
||||
splitsInRequestBitVector.set(splitIndex)
|
||||
}
|
||||
}
|
||||
|
||||
numThreadsToCreate = numThreadsToCreate - 1
|
||||
}
|
||||
|
||||
// Sleep for a while before creating new threads
|
||||
Thread.sleep(CustomBlockedInMemoryShuffle.MinKnockInterval)
|
||||
}
|
||||
|
||||
threadPool.shutdown()
|
||||
combiners
|
||||
})
|
||||
}
|
||||
|
||||
def selectRandomSplit: Int = {
|
||||
var requiredSplits = new ArrayBuffer[Int]
|
||||
|
||||
synchronized {
|
||||
for (i <- 0 until totalSplits) {
|
||||
if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
|
||||
requiredSplits += i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (requiredSplits.size > 0) {
|
||||
requiredSplits(CustomBlockedInMemoryShuffle.ranGen.nextInt(
|
||||
requiredSplits.size))
|
||||
} else {
|
||||
-1
|
||||
}
|
||||
}
|
||||
|
||||
class ShuffleConsumer(mergeCombiners: (C, C) => C)
|
||||
extends Thread with Logging {
|
||||
override def run: Unit = {
|
||||
// Run until all splits are here
|
||||
while (hasSplits < totalSplits) {
|
||||
var splitIndex = -1
|
||||
var recvByteArray: Array[Byte] = null
|
||||
|
||||
try {
|
||||
var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
|
||||
splitIndex = tempPair._1
|
||||
recvByteArray = tempPair._2
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("Exception during taking data from receivedData")
|
||||
}
|
||||
}
|
||||
|
||||
val inputStream =
|
||||
new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
|
||||
|
||||
try{
|
||||
while (true) {
|
||||
val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
|
||||
combiners(k) = combiners.get(k) match {
|
||||
case Some(oldC) => mergeCombiners(oldC, c)
|
||||
case None => c
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case e: EOFException => { }
|
||||
}
|
||||
inputStream.close()
|
||||
|
||||
// Consumption completed. Update stats.
|
||||
hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
|
||||
|
||||
// Split has been received only if all the blocks have been received
|
||||
if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
|
||||
hasSplitsBitVector.synchronized {
|
||||
hasSplitsBitVector.set(splitIndex)
|
||||
}
|
||||
hasSplits += 1
|
||||
}
|
||||
|
||||
// We have received splitIndex
|
||||
splitsInRequestBitVector.synchronized {
|
||||
splitsInRequestBitVector.set(splitIndex, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class ShuffleClient(hostAddress: String, listenPort: Int, shuffleId: Int,
|
||||
inputId: Int, myId: Int, splitIndex: Int)
|
||||
extends Thread with Logging {
|
||||
private var peerSocketToSource: Socket = null
|
||||
private var oosSource: ObjectOutputStream = null
|
||||
private var oisSource: ObjectInputStream = null
|
||||
|
||||
private var receptionSucceeded = false
|
||||
|
||||
override def run: Unit = {
|
||||
// Setup the timeout mechanism
|
||||
var timeOutTask = new TimerTask {
|
||||
override def run: Unit = {
|
||||
cleanUpConnections()
|
||||
}
|
||||
}
|
||||
|
||||
var timeOutTimer = new Timer
|
||||
timeOutTimer.schedule(timeOutTask,
|
||||
CustomParallelLocalFileShuffle.MaxKnockInterval)
|
||||
|
||||
try {
|
||||
// Everything will break if BLOCKNUM is not correctly received
|
||||
// First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
|
||||
peerSocketToSource = new Socket(hostAddress, listenPort)
|
||||
oosSource =
|
||||
new ObjectOutputStream(peerSocketToSource.getOutputStream)
|
||||
oosSource.flush()
|
||||
var isSource = peerSocketToSource.getInputStream
|
||||
oisSource = new ObjectInputStream(isSource)
|
||||
|
||||
// Send path information
|
||||
oosSource.writeObject((shuffleId, inputId, myId))
|
||||
|
||||
// TODO: Can be optimized. No need to do it everytime.
|
||||
// Receive BLOCKNUM
|
||||
totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int]
|
||||
|
||||
// Request specific block
|
||||
oosSource.writeObject(hasBlocksInSplit(splitIndex))
|
||||
|
||||
// Good to go. First, receive the length of the requested file
|
||||
var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
|
||||
logInfo("Received requestedFileLen = " + requestedFileLen)
|
||||
|
||||
// Turn the timer OFF, if the sender responds before timeout
|
||||
timeOutTimer.cancel()
|
||||
|
||||
val requestSplit = "%d/%d/%d-%d".format(shuffleId, inputId, myId,
|
||||
hasBlocksInSplit(splitIndex))
|
||||
|
||||
// Receive the file
|
||||
if (requestedFileLen != -1) {
|
||||
val readStartTime = System.currentTimeMillis
|
||||
logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
|
||||
|
||||
// Receive data in an Array[Byte]
|
||||
var recvByteArray = new Array[Byte](requestedFileLen)
|
||||
var alreadyRead = 0
|
||||
var bytesRead = 0
|
||||
|
||||
while (alreadyRead != requestedFileLen) {
|
||||
bytesRead = isSource.read(recvByteArray, alreadyRead,
|
||||
requestedFileLen - alreadyRead)
|
||||
if (bytesRead > 0) {
|
||||
alreadyRead = alreadyRead + bytesRead
|
||||
}
|
||||
}
|
||||
|
||||
// Make it available to the consumer
|
||||
try {
|
||||
receivedData.put((splitIndex, recvByteArray))
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("Exception during putting data into receivedData")
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: Update of bitVectors are now done by the consumer
|
||||
|
||||
receptionSucceeded = true
|
||||
|
||||
logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
|
||||
val readTime = System.currentTimeMillis - readStartTime
|
||||
logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.")
|
||||
} else {
|
||||
throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit)
|
||||
}
|
||||
} catch {
|
||||
// EOFException is expected to happen because sender can break
|
||||
// connection due to timeout
|
||||
case eofe: java.io.EOFException => { }
|
||||
case e: Exception => {
|
||||
logInfo("ShuffleClient had a " + e)
|
||||
}
|
||||
} finally {
|
||||
// If reception failed, unset for future retry
|
||||
if (!receptionSucceeded) {
|
||||
splitsInRequestBitVector.synchronized {
|
||||
splitsInRequestBitVector.set(splitIndex, false)
|
||||
}
|
||||
}
|
||||
cleanUpConnections()
|
||||
}
|
||||
}
|
||||
|
||||
private def cleanUpConnections(): Unit = {
|
||||
if (oisSource != null) {
|
||||
oisSource.close()
|
||||
}
|
||||
if (oosSource != null) {
|
||||
oosSource.close()
|
||||
}
|
||||
if (peerSocketToSource != null) {
|
||||
peerSocketToSource.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object CustomBlockedInMemoryShuffle extends Logging {
|
||||
// Cache for keeping the splits around
|
||||
val splitsCache = new HashMap[String, Array[Byte]]
|
||||
|
||||
// Used thoughout the code for small and large waits/timeouts
|
||||
private var BlockSize_ = 1024 * 1024
|
||||
|
||||
private var MinKnockInterval_ = 1000
|
||||
private var MaxKnockInterval_ = 5000
|
||||
|
||||
// Maximum number of connections
|
||||
private var MaxRxConnections_ = 4
|
||||
private var MaxTxConnections_ = 8
|
||||
|
||||
private var initialized = false
|
||||
private var nextShuffleId = new AtomicLong(0)
|
||||
|
||||
// Variables initialized by initializeIfNeeded()
|
||||
private var shuffleDir: File = null
|
||||
|
||||
private var shuffleServer: ShuffleServer = null
|
||||
private var serverAddress = InetAddress.getLocalHost.getHostAddress
|
||||
private var serverPort: Int = -1
|
||||
|
||||
// Random number generator
|
||||
var ranGen = new Random
|
||||
|
||||
private def initializeIfNeeded() = synchronized {
|
||||
if (!initialized) {
|
||||
// Load config parameters
|
||||
BlockSize_ = System.getProperty(
|
||||
"spark.blockedInMemoryShuffle.blockSize", "1024").toInt * 1024
|
||||
|
||||
MinKnockInterval_ = System.getProperty(
|
||||
"spark.blockedInMemoryShuffle.minKnockInterval", "1000").toInt
|
||||
MaxKnockInterval_ = System.getProperty(
|
||||
"spark.blockedInMemoryShuffle.maxKnockInterval", "5000").toInt
|
||||
|
||||
MaxRxConnections_ = System.getProperty(
|
||||
"spark.blockedInMemoryShuffle.maxRxConnections", "4").toInt
|
||||
MaxTxConnections_ = System.getProperty(
|
||||
"spark.blockedInMemoryShuffle.maxTxConnections", "8").toInt
|
||||
|
||||
// TODO: localDir should be created by some mechanism common to Spark
|
||||
// so that it can be shared among shuffle, broadcast, etc
|
||||
val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
|
||||
var tries = 0
|
||||
var foundLocalDir = false
|
||||
var localDir: File = null
|
||||
var localDirUuid: UUID = null
|
||||
while (!foundLocalDir && tries < 10) {
|
||||
tries += 1
|
||||
try {
|
||||
localDirUuid = UUID.randomUUID
|
||||
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
|
||||
if (!localDir.exists) {
|
||||
localDir.mkdirs()
|
||||
foundLocalDir = true
|
||||
}
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logWarning("Attempt " + tries + " to create local dir failed", e)
|
||||
}
|
||||
}
|
||||
if (!foundLocalDir) {
|
||||
logError("Failed 10 attempts to create local dir in " + localDirRoot)
|
||||
System.exit(1)
|
||||
}
|
||||
shuffleDir = new File(localDir, "shuffle")
|
||||
shuffleDir.mkdirs()
|
||||
logInfo("Shuffle dir: " + shuffleDir)
|
||||
|
||||
// Create and start the shuffleServer
|
||||
shuffleServer = new ShuffleServer
|
||||
shuffleServer.setDaemon(true)
|
||||
shuffleServer.start()
|
||||
logInfo("ShuffleServer started...")
|
||||
|
||||
initialized = true
|
||||
}
|
||||
}
|
||||
|
||||
def BlockSize = BlockSize_
|
||||
|
||||
def MinKnockInterval = MinKnockInterval_
|
||||
def MaxKnockInterval = MaxKnockInterval_
|
||||
|
||||
def MaxRxConnections = MaxRxConnections_
|
||||
def MaxTxConnections = MaxTxConnections_
|
||||
|
||||
def getSplitName(shuffleId: Long, inputId: Int, outputId: Int,
|
||||
blockId: Int): String = {
|
||||
initializeIfNeeded()
|
||||
// Adding shuffleDir is unnecessary. Added to keep the parsers working
|
||||
return "%s/%d/%d/%d-%d".format(shuffleDir, shuffleId, inputId, outputId,
|
||||
blockId)
|
||||
}
|
||||
|
||||
def getBlockNumOutputName(shuffleId: Long, inputId: Int,
|
||||
outputId: Int): String = {
|
||||
initializeIfNeeded()
|
||||
return "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir, shuffleId, inputId,
|
||||
outputId)
|
||||
}
|
||||
|
||||
def newShuffleId(): Long = {
|
||||
nextShuffleId.getAndIncrement()
|
||||
}
|
||||
|
||||
// Returns a standard ThreadFactory except all threads are daemons
|
||||
private def newDaemonThreadFactory: ThreadFactory = {
|
||||
new ThreadFactory {
|
||||
def newThread(r: Runnable): Thread = {
|
||||
var t = Executors.defaultThreadFactory.newThread(r)
|
||||
t.setDaemon(true)
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wrapper over newFixedThreadPool
|
||||
def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
|
||||
var threadPool =
|
||||
Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
|
||||
|
||||
threadPool.setThreadFactory(newDaemonThreadFactory)
|
||||
|
||||
return threadPool
|
||||
}
|
||||
|
||||
class ShuffleServer
|
||||
extends Thread with Logging {
|
||||
var threadPool =
|
||||
newDaemonFixedThreadPool(CustomBlockedInMemoryShuffle.MaxTxConnections)
|
||||
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
override def run: Unit = {
|
||||
serverSocket = new ServerSocket(0)
|
||||
serverPort = serverSocket.getLocalPort
|
||||
|
||||
logInfo("ShuffleServer started with " + serverSocket)
|
||||
logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
clientSocket = serverSocket.accept()
|
||||
} catch {
|
||||
case e: Exception => { }
|
||||
}
|
||||
if (clientSocket != null) {
|
||||
logInfo("Serve: Accepted new client connection:" + clientSocket)
|
||||
try {
|
||||
threadPool.execute(new ShuffleServerThread(clientSocket))
|
||||
} catch {
|
||||
// In failure, close socket here; else, the thread will close it
|
||||
case ioe: IOException => {
|
||||
clientSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (serverSocket != null) {
|
||||
logInfo("ShuffleServer now stopping...")
|
||||
serverSocket.close()
|
||||
}
|
||||
}
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown()
|
||||
}
|
||||
|
||||
class ShuffleServerThread(val clientSocket: Socket)
|
||||
extends Thread with Logging {
|
||||
private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
|
||||
os.flush()
|
||||
private val bos = new BufferedOutputStream(os)
|
||||
bos.flush()
|
||||
private val oos = new ObjectOutputStream(os)
|
||||
oos.flush()
|
||||
private val ois = new ObjectInputStream(clientSocket.getInputStream)
|
||||
|
||||
logInfo("new ShuffleServerThread is running")
|
||||
|
||||
override def run: Unit = {
|
||||
try {
|
||||
// Receive basic path information
|
||||
val (shuffleId, myIndex, outputId) =
|
||||
ois.readObject.asInstanceOf[(Int, Int, Int)]
|
||||
|
||||
var requestedSplit =
|
||||
"%s/%d/%d/%d".format(shuffleDir, shuffleId, myIndex, outputId)
|
||||
logInfo("requestedSplit: " + requestedSplit)
|
||||
|
||||
// Read BLOCKNUM and send back the total number of blocks
|
||||
val blockNumName = "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir,
|
||||
shuffleId, myIndex, outputId)
|
||||
|
||||
val blockNumIn = new ObjectInputStream(new ByteArrayInputStream(
|
||||
CustomBlockedInMemoryShuffle.splitsCache(blockNumName)))
|
||||
val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int]
|
||||
blockNumIn.close()
|
||||
|
||||
oos.writeObject(BLOCKNUM)
|
||||
|
||||
// Receive specific block request
|
||||
val blockId = ois.readObject.asInstanceOf[Int]
|
||||
|
||||
// Ready to send
|
||||
requestedSplit = requestedSplit + "-" + blockId
|
||||
|
||||
// Send the length of the requestedSplit to let the receiver know that
|
||||
// transfer is about to start
|
||||
// In the case of receiver timeout and connection close, this will
|
||||
// throw a java.net.SocketException: Broken pipe
|
||||
var requestedSplitLen = -1
|
||||
|
||||
try {
|
||||
requestedSplitLen =
|
||||
CustomBlockedInMemoryShuffle.splitsCache(requestedSplit).length
|
||||
} catch {
|
||||
case e: Exception => { }
|
||||
}
|
||||
|
||||
oos.writeObject(requestedSplitLen)
|
||||
oos.flush()
|
||||
|
||||
logInfo("requestedSplitLen = " + requestedSplitLen)
|
||||
|
||||
// Read and send the requested file
|
||||
if (requestedSplitLen != -1) {
|
||||
// Send
|
||||
bos.write(CustomBlockedInMemoryShuffle.splitsCache(requestedSplit),
|
||||
0, requestedSplitLen)
|
||||
bos.flush()
|
||||
} else {
|
||||
// Close the connection
|
||||
}
|
||||
} catch {
|
||||
// If something went wrong, e.g., the worker at the other end died etc
|
||||
// then close everything up
|
||||
// Exception can happen if the receiver stops receiving
|
||||
case e: Exception => {
|
||||
logInfo("ShuffleServerThread had a " + e)
|
||||
}
|
||||
} finally {
|
||||
logInfo("ShuffleServerThread is closing streams and sockets")
|
||||
ois.close()
|
||||
// TODO: Following can cause "java.net.SocketException: Socket closed"
|
||||
oos.close()
|
||||
bos.close()
|
||||
clientSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -506,7 +506,7 @@ object CustomParallelInMemoryShuffle extends Logging {
|
|||
shuffleDir + "/" + ois.readObject.asInstanceOf[String]
|
||||
logInfo("requestedSplit: " + requestedSplit)
|
||||
|
||||
// Send the lendth of the requestedSplit to let the receiver know that
|
||||
// Send the length of the requestedSplit to let the receiver know that
|
||||
// transfer is about to start
|
||||
// In the case of receiver timeout and connection close, this will
|
||||
// throw a java.net.SocketException: Broken pipe
|
||||
|
|
Loading…
Reference in a new issue