Bug fixes. Not yet parallel.

Mosharaf Chowdhury 2010-12-04 00:06:47 -08:00
parent 52086cef32
commit 0d7ca7751e


@@ -2,7 +2,7 @@ package spark
import java.io._
import java.net._
import java.util.{Timer, TimerTask, UUID}
import java.util.{BitSet, Random, Timer, TimerTask, UUID}
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory}
@@ -15,6 +15,11 @@ import scala.collection.mutable.{ArrayBuffer, HashMap}
*/
@serializable
class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
@transient var totalSplits = 0
@transient var hasSplits = 0
@transient var hasSplitsBitVector: BitSet = null
@transient var combiners: HashMap[K,C] = null
override def compute(input: RDD[(K, V)],
numOutputSplits: Int,
createCombiner: V => C,
@@ -71,11 +76,20 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
// Return an RDD that does each of the merges for a given partition
val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
return indexes.flatMap((myId: Int) => {
val combiners = new HashMap[K, C]
for ((serverAddress, serverPort, inputId) <- splitsByUri) {
totalSplits = splitsByUri.size
hasSplitsBitVector = new BitSet (totalSplits)
combiners = new HashMap[K, C]
while (hasSplits < totalSplits) {
// Select a random split to pull
val splitIndex = selectRandomSplit
val (serverAddress, serverPort, inputId) =
splitsByUri (splitIndex)
val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId)
val shuffleClient = new ShuffleClient(serverAddress, serverPort, requestPath)
val shuffleClient =
new ShuffleClient(serverAddress, serverPort, requestPath)
val readStartTime = System.currentTimeMillis
logInfo ("BEGIN READ: " + requestPath)
shuffleClient.start
@@ -96,6 +110,11 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
}
inputStream.close
hasSplits += 1
hasSplitsBitVector.synchronized {
hasSplitsBitVector.set (splitIndex)
}
logInfo ("END READ: " + requestPath)
val readTime = (System.currentTimeMillis - readStartTime)
logInfo ("Reading " + requestPath + " took " + readTime + " millis.")
@@ -103,6 +122,107 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
combiners
})
}
def selectRandomSplit: Int = {
var requiredSplits = new ArrayBuffer[Int]
hasSplitsBitVector.synchronized {
for (i <- 0 until totalSplits) {
if (!hasSplitsBitVector.get(i)) {
requiredSplits += i
}
}
}
if (requiredSplits.size > 0) {
requiredSplits(LocalFileShuffle.ranGen.nextInt (requiredSplits.size))
} else {
-1
}
}
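
The new selectRandomSplit above is the core of this change: splits already fetched are remembered in a java.util.BitSet, and the next split to pull is chosen at random from the ones still missing. A minimal standalone sketch of how the bit vector, the random pick, and the fetch loop fit together (pullAllSplits and fetchSplit are hypothetical names, not part of the commit):

// Sketch only: same pattern as compute/selectRandomSplit above, with the
// actual network fetch abstracted behind fetchSplit.
import java.util.{BitSet, Random}

def pullAllSplits(totalSplits: Int, fetchSplit: Int => Unit): Unit = {
  val hasSplitsBitVector = new BitSet(totalSplits)
  val ranGen = new Random
  var hasSplits = 0
  while (hasSplits < totalSplits) {
    // Collect indices not yet fetched, then pick one at random (selectRandomSplit).
    val remaining = (0 until totalSplits).filter(i => !hasSplitsBitVector.get(i))
    val splitIndex = remaining(ranGen.nextInt(remaining.size))
    fetchSplit(splitIndex)                       // pull that split from its server
    hasSplitsBitVector.synchronized {            // mark it as received
      hasSplitsBitVector.set(splitIndex)
    }
    hasSplits += 1
  }
}
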
class ShuffleClient (hostAddress: String, listenPort: Int, requestPath: String)
extends Thread with Logging {
private var peerSocketToSource: Socket = null
private var oosSource: ObjectOutputStream = null
private var oisSource: ObjectInputStream = null
var byteArray: Array[Byte] = null
override def run: Unit = {
// Setup the timeout mechanism
var timeOutTask = new TimerTask {
override def run: Unit = {
cleanUpConnections
}
}
var timeOutTimer = new Timer
// TODO: Set wait timer
// TODO: If it's too small, things FAIL
timeOutTimer.schedule (timeOutTask, 10000)
logInfo ("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath))
try {
// Connect to the source
peerSocketToSource = new Socket (hostAddress, listenPort)
oosSource =
new ObjectOutputStream (peerSocketToSource.getOutputStream)
oosSource.flush
var isSource = peerSocketToSource.getInputStream
oisSource = new ObjectInputStream (isSource)
// Send the request
oosSource.writeObject(requestPath)
// Receive the length of the requested file
var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
logInfo ("Received requestedFileLen = " + requestedFileLen)
// Turn the timer OFF, if the sender responds before timeout
timeOutTimer.cancel
// Receive the file
if (requestedFileLen != -1) {
byteArray = new Array[Byte] (requestedFileLen)
var bytesRead = isSource.read (byteArray, 0, byteArray.length)
var alreadyRead = bytesRead
while (alreadyRead < requestedFileLen) {
bytesRead = isSource.read(byteArray, alreadyRead,
(byteArray.length - alreadyRead))
if(bytesRead > 0) {
alreadyRead = alreadyRead + bytesRead
}
}
} else {
throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath)
}
} catch {
// EOFException is expected to happen because sender can break
// connection due to timeout
case eofe: java.io.EOFException => { }
case e: Exception => {
logInfo ("ShuffleClient had a " + e)
}
} finally {
cleanUpConnections
}
}
private def cleanUpConnections: Unit = {
if (oisSource != null) {
oisSource.close
}
if (oosSource != null) {
oosSource.close
}
if (peerSocketToSource != null) {
peerSocketToSource.close
}
}
}
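
ShuffleClient above guards the blocking connect and read with a java.util.Timer: a TimerTask closes the connection if the server has not answered within 10 seconds, and the timer is cancelled once the reply arrives. A stripped-down sketch of that timeout pattern (fetchWithTimeout and talk are hypothetical names):

// Sketch only: the Timer/TimerTask timeout pattern used by ShuffleClient.run.
import java.net.Socket
import java.util.{Timer, TimerTask}

def fetchWithTimeout(host: String, port: Int, timeoutMillis: Long)
    (talk: Socket => Array[Byte]): Array[Byte] = {
  val socket = new Socket(host, port)
  val timer = new Timer
  timer.schedule(new TimerTask {
    override def run(): Unit = socket.close()    // abort the exchange if the peer stays silent
  }, timeoutMillis)
  try {
    talk(socket)                                 // request/response happens here; fails fast if the timer fired
  } finally {
    timer.cancel()                               // disarm the timer either way
    socket.close()
  }
}
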
}
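
The receive loop in ShuffleClient.run handles the fact that InputStream.read may return fewer bytes than requested: it keeps reading into the same array until requestedFileLen bytes have arrived. The same loop, factored into a standalone helper for clarity (readFully is a hypothetical name, not in the commit):

// Sketch only: read exactly `length` bytes, as the client loop above does.
import java.io.InputStream

def readFully(in: InputStream, length: Int): Array[Byte] = {
  val buffer = new Array[Byte](length)
  var alreadyRead = 0
  while (alreadyRead < length) {
    val bytesRead = in.read(buffer, alreadyRead, length - alreadyRead)
    if (bytesRead < 0) {
      throw new java.io.EOFException("Stream ended before " + length + " bytes arrived")
    }
    alreadyRead += bytesRead
  }
  buffer
}
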
object LocalFileShuffle extends Logging {
@@ -116,6 +236,9 @@ object LocalFileShuffle extends Logging {
private var serverAddress = InetAddress.getLocalHost.getHostAddress
private var serverPort: Int = -1
// Random number generator
var ranGen = new Random
private def initializeIfNeeded() = synchronized {
if (!initialized) {
// TODO: localDir should be created by some mechanism common to Spark
@@ -203,7 +326,7 @@ object LocalFileShuffle extends Logging {
serverPort = serverSocket.getLocalPort
logInfo ("ShuffleServer started with " + serverSocket)
logInfo ("Local URI: " + serverAddress + ":" + serverPort)
logInfo ("Local URI: http://" + serverAddress + ":" + serverPort)
try {
while (true) {
@@ -239,6 +362,8 @@ object LocalFileShuffle extends Logging {
extends Thread with Logging {
private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
os.flush
private val bos = new BufferedOutputStream (os)
bos.flush
private val oos = new ObjectOutputStream (os)
oos.flush
private val ois = new ObjectInputStream (clientSocket.getInputStream)
@@ -286,13 +411,12 @@ object LocalFileShuffle extends Logging {
if(bytesRead > 0) {
alreadyRead = alreadyRead + bytesRead
}
}
}
bis.close
// Send
os.write (byteArray, 0, byteArray.length)
os.flush
bos.write (byteArray, 0, byteArray.length)
bos.flush
} else {
// Close the connection
}
@@ -307,92 +431,11 @@ object LocalFileShuffle extends Logging {
logInfo ("ShuffleServerThread is closing streams and sockets")
ois.close
// TODO: Following can cause "java.net.SocketException: Socket closed"
oos.close
oos.close
bos.close
clientSocket.close
}
}
}
}
}
class ShuffleClient (hostAddress: String, listenPort: Int, requestPath: String)
extends Thread with Logging {
private var peerSocketToSource: Socket = null
private var oosSource: ObjectOutputStream = null
private var oisSource: ObjectInputStream = null
var byteArray: Array[Byte] = null
override def run: Unit = {
// Setup the timeout mechanism
var timeOutTask = new TimerTask {
override def run: Unit = {
cleanUpConnections
}
}
var timeOutTimer = new Timer
// TODO: Set wait timer
timeOutTimer.schedule (timeOutTask, 1000)
logInfo ("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath))
try {
// Connect to the source
peerSocketToSource = new Socket (hostAddress, listenPort)
oosSource =
new ObjectOutputStream (peerSocketToSource.getOutputStream)
oosSource.flush
var isSource = peerSocketToSource.getInputStream
oisSource = new ObjectInputStream (isSource)
// Send the request
oosSource.writeObject(requestPath)
// Receive the length of the requested file
var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
logInfo ("Received requestedFileLen = " + requestedFileLen)
// Turn the timer OFF, if the sender responds before timeout
timeOutTimer.cancel
// Receive the file
if (requestedFileLen != -1) {
byteArray = new Array[Byte] (requestedFileLen)
var bytesRead = isSource.read (byteArray, 0, byteArray.length)
var alreadyRead = bytesRead
while (alreadyRead < requestedFileLen) {
bytesRead = isSource.read(byteArray, alreadyRead,
(byteArray.length - alreadyRead))
if(bytesRead > 0) {
alreadyRead = alreadyRead + bytesRead
}
}
} else {
throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath)
}
} catch {
// EOFException is expected to happen because sender can break
// connection due to timeout
case eofe: java.io.EOFException => { }
case e: Exception => {
logInfo ("ShuffleClient had a " + e)
}
} finally {
cleanUpConnections
}
}
private def cleanUpConnections: Unit = {
if (oisSource != null) {
oisSource.close
}
if (oosSource != null) {
oosSource.close
}
if (peerSocketToSource != null) {
peerSocketToSource.close
}
}
}
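
One of the bug fixes above is on the sending side: ShuffleServerThread now writes the file through a BufferedOutputStream (bos.write / bos.flush) instead of the raw socket stream (os.write / os.flush). A minimal sketch of that buffered reply path, with hypothetical names:

// Sketch only: buffered write of the payload, as in ShuffleServerThread after this commit.
import java.io.BufferedOutputStream
import java.net.Socket

def sendBytes(clientSocket: Socket, byteArray: Array[Byte]): Unit = {
  val bos = new BufferedOutputStream(clientSocket.getOutputStream)
  bos.write(byteArray, 0, byteArray.length)   // coalesce small writes into larger packets
  bos.flush()                                 // push any buffered bytes out before returning
}
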