Some bug fixes and logging fixes for broadcast.

Matei Zaharia 2012-10-01 15:20:42 -07:00
parent c1db5a849b
commit 802aa8aef9
10 changed files with 113 additions and 108 deletions
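The thread running through these changes: broadcast variables are now identified by a plain Long handed out by the BroadcastManager instead of a java.util.UUID, and the same id is reused for the BlockManager key and the tracker registrations. A condensed sketch of the reshaped API, taken from the diffs below (only the pieces this commit touches):

    // Condensed from Broadcast.scala and BroadcastFactory.scala as changed below.
    abstract class Broadcast[T](id: Long) extends Serializable {
      def value: T
      override def toString = "spark.Broadcast(" + id + ")"
    }

    trait BroadcastFactory {
      def initialize(isMaster: Boolean): Unit
      def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T]
      def stop(): Unit
    }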

core/src/main/scala/spark/RDD.scala

@@ -99,7 +99,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def getStorageLevel = storageLevel
def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
if (!level.useDisk && level.replication < 2) {
throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
}
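The guard above only accepts levels that either spill to disk or keep at least two replicas, since a purely in-memory, single-copy checkpoint would vanish with its executor. An illustrative pair of calls from inside the spark package (the method is now private[spark]); MEMORY_ONLY here just stands for any level with useDisk = false and replication = 1:

    rdd.checkpoint(StorageLevel.MEMORY_AND_DISK_2)  // disk-backed, 2 replicas -> accepted
    rdd.checkpoint(StorageLevel.MEMORY_ONLY)        // no disk, single copy    -> throws Exception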

core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala

@@ -11,14 +11,17 @@ import scala.math
import spark._
import spark.storage.StorageLevel
class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id)
with Logging
with Serializable {
def value = value_
def blockId: String = "broadcast_" + id
MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -45,7 +48,7 @@ extends Broadcast[T] with Logging with Serializable {
// Used only in Workers
@transient var ttGuide: TalkToGuide = null
@transient var hostAddress = Utils.localIpAddress
@transient var hostAddress = Utils.localIpAddress()
@transient var listenPort = -1
@transient var guidePort = -1
@@ -106,17 +109,19 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource
// Register with the Tracker
MultiTracker.registerBroadcast(uuid,
MultiTracker.registerBroadcast(id,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
case Some(x) => x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables()
@@ -131,18 +136,17 @@ extends Broadcast[T] with Logging with Serializable {
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else {
logError("Reading Broadcasted variable " + uuid + " failed")
logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
@@ -254,8 +258,8 @@ extends Broadcast[T] with Logging with Serializable {
}
}
def receiveBroadcast(variableUUID: UUID): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableUUID)
def receiveBroadcast(variableID: Long): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
@@ -760,7 +764,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
MultiTracker.unregisterBroadcast(uuid)
MultiTracker.unregisterBroadcast(id)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
@@ -1025,7 +1029,10 @@ extends Broadcast[T] with Logging with Serializable {
class BitTorrentBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new BitTorrentBroadcast[T](value_, isLocal)
def stop() = MultiTracker.stop
def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new BitTorrentBroadcast[T](value_, isLocal, id)
def stop() { MultiTracker.stop() }
}
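Both the constructor and readObject above now cache the broadcast value in the BlockManager under a key derived from the id rather than from a UUID string. The pattern, pulled together from the diff (putSingle/getSingle exactly as used above):

    val blockId = "broadcast_" + id                  // e.g. "broadcast_0"
    // Master, and the first reader on each worker node, cache the value locally:
    SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
    // Later readers on the same node hit the cache instead of re-fetching:
    SparkEnv.get.blockManager.getSingle(blockId) match {
      case Some(x) => value_ = x.asInstanceOf[T]
      case None    => // fetch blocks via the tracker, then putSingle as above
    }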

core/src/main/scala/spark/broadcast/Broadcast.scala

@@ -1,23 +1,17 @@
package spark.broadcast
import java.io._
import java.net._
import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.Map
import java.util.concurrent.atomic.AtomicLong
import spark._
trait Broadcast[T] extends Serializable {
val uuid = UUID.randomUUID
abstract class Broadcast[T](id: Long) extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.
override def toString = "spark.Broadcast(" + uuid + ")"
override def toString = "spark.Broadcast(" + id + ")"
}
class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable {
@@ -49,14 +43,10 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
broadcastFactory.stop()
}
private def getBroadcastFactory: BroadcastFactory = {
if (broadcastFactory == null) {
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
}
broadcastFactory
}
private val nextBroadcastId = new AtomicLong(0)
def newBroadcast[T](value_ : T, isLocal: Boolean) = broadcastFactory.newBroadcast[T](value_, isLocal)
def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def isMaster = isMaster_
}
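With the id now a constructor argument, the manager above becomes the single allocator of broadcast ids via its AtomicLong, so ids simply count up per driver. A hedged usage sketch, assuming the usual SparkContext.broadcast entry point (not part of this diff) delegates to newBroadcast:

    // sc is an existing SparkContext; the broadcast values are made up.
    val words  = sc.broadcast(Array("a", "b", "c"))   // receives id 0
    val lookup = sc.broadcast(Map(1 -> "x"))          // receives id 1
    println(lookup)                                   // prints spark.Broadcast(1)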

core/src/main/scala/spark/broadcast/BroadcastFactory.scala

@@ -8,6 +8,6 @@ package spark.broadcast
*/
trait BroadcastFactory {
def initialize(isMaster: Boolean): Unit
def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
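Every factory implementation now has to accept the id and pass it through to its Broadcast subclass. A minimal hypothetical implementation of the updated trait (DummyBroadcast and DummyBroadcastFactory are made-up names, not part of Spark):

    // Hypothetical example of the new three-argument contract.
    class DummyBroadcast[T](value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) {
      def value = value_
    }

    class DummyBroadcastFactory extends BroadcastFactory {
      def initialize(isMaster: Boolean) { /* nothing to set up */ }
      def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
        new DummyBroadcast[T](value_, isLocal, id)
      def stop() { }
    }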

core/src/main/scala/spark/broadcast/HttpBroadcast.scala

@@ -12,34 +12,34 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark._
import spark.storage.StorageLevel
class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def blockId: String = "broadcast_" + id
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
if (!isLocal) {
HttpBroadcast.write(uuid, value_)
HttpBroadcast.write(id, value_)
}
// Called by JVM when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) => value_ = x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](uuid)
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
@@ -47,9 +47,12 @@ extends Broadcast[T] with Logging with Serializable {
}
class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isMaster: Boolean) = HttpBroadcast.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal)
def stop() = HttpBroadcast.stop()
def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def stop() { HttpBroadcast.stop() }
}
private object HttpBroadcast extends Logging {
@@ -94,8 +97,8 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri)
}
def write(uuid: UUID, value: Any) {
val file = new File(broadcastDir, "broadcast-" + uuid)
def write(id: Long, value: Any) {
val file = new File(broadcastDir, "broadcast-" + id)
val out: OutputStream = if (compress) {
new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering
} else {
@@ -107,8 +110,8 @@ private object HttpBroadcast extends Logging {
serOut.close()
}
def read[T](uuid: UUID): T = {
val url = serverUri + "/broadcast-" + uuid
def read[T](id: Long): T = {
val url = serverUri + "/broadcast-" + id
var in = if (compress) {
new LZFInputStream(new URL(url).openStream()) // Does its own buffering
} else {
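With Long ids, the files the master writes and the URLs the workers fetch get short, predictable names. A small sketch of the convention used by write() and read() above (broadcastDir and serverUri are HttpBroadcast's own fields; the example id is arbitrary):

    val id   = 2L
    val file = new File(broadcastDir, "broadcast-" + id)  // what write() creates on the master
    val url  = serverUri + "/broadcast-" + id             // what read() fetches on a worker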

core/src/main/scala/spark/broadcast/MultiTracker.scala

@@ -2,8 +2,7 @@ package spark.broadcast
import java.io._
import java.net._
import java.util.{UUID, Random}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import java.util.Random
import scala.collection.mutable.Map
@@ -18,7 +17,7 @@ extends Logging {
val FIND_BROADCAST_TRACKER = 2
// Map to keep track of guides of ongoing broadcasts
var valueToGuideMap = Map[UUID, SourceInfo]()
var valueToGuideMap = Map[Long, SourceInfo]()
// Random number generator
var ranGen = new Random
@@ -154,44 +153,44 @@ extends Logging {
val messageType = ois.readObject.asInstanceOf[Int]
if (messageType == REGISTER_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
// Receive hostAddress and listenPort
val gInfo = ois.readObject.asInstanceOf[SourceInfo]
// Add to the map
valueToGuideMap.synchronized {
valueToGuideMap += (uuid -> gInfo)
valueToGuideMap += (id -> gInfo)
}
logInfo ("New broadcast " + uuid + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
// Remove from the map
valueToGuideMap.synchronized {
valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToDefault)
valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
}
logInfo ("Broadcast " + uuid + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == FIND_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
var gInfo =
if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid)
if (valueToGuideMap.contains(id)) valueToGuideMap(id)
else SourceInfo("", SourceInfo.TxNotStartedRetry)
logDebug("Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
// Send reply back
oos.writeObject(gInfo)
@@ -224,7 +223,7 @@ extends Logging {
}
}
def getGuideInfo(variableUUID: UUID): SourceInfo = {
def getGuideInfo(variableLong: Long): SourceInfo = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
@@ -247,8 +246,8 @@ extends Logging {
oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
oosTracker.flush()
// Send UUID and receive GuideInfo
oosTracker.writeObject(variableUUID)
// Send Long and receive GuideInfo
oosTracker.writeObject(variableLong)
oosTracker.flush()
gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
} catch {
@@ -276,7 +275,7 @@ extends Logging {
return gInfo
}
def registerBroadcast(uuid: UUID, gInfo: SourceInfo) {
def registerBroadcast(id: Long, gInfo: SourceInfo) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
@@ -286,8 +285,8 @@ extends Logging {
oosST.writeObject(REGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send UUID of this broadcast
oosST.writeObject(uuid)
// Send Long of this broadcast
oosST.writeObject(id)
oosST.flush()
// Send this tracker's information
@@ -303,7 +302,7 @@ extends Logging {
socket.close()
}
def unregisterBroadcast(uuid: UUID) {
def unregisterBroadcast(id: Long) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
@@ -313,8 +312,8 @@ extends Logging {
oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send UUID of this broadcast
oosST.writeObject(uuid)
// Send Long of this broadcast
oosST.writeObject(id)
oosST.flush()
// Receive ACK and throw it away

core/src/main/scala/spark/broadcast/TreeBroadcast.scala

@@ -10,14 +10,15 @@ import scala.math
import spark._
import spark.storage.StorageLevel
class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def blockId = "broadcast_" + id
MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -35,7 +36,7 @@ extends Broadcast[T] with Logging with Serializable {
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress
@transient var hostAddress = Utils.localIpAddress()
@transient var listenPort = -1
@transient var guidePort = -1
@@ -43,7 +44,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast
sendBroadcast()
}
def sendBroadcast() {
@@ -84,20 +85,22 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource
// Register with the Tracker
MultiTracker.registerBroadcast(uuid,
MultiTracker.registerBroadcast(id,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
case Some(x) => x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables
initializeWorkerVariables()
logInfo("Local host address: " + hostAddress)
@@ -108,18 +111,17 @@ extends Broadcast[T] with Logging with Serializable {
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else {
logError("Reading Broadcasted variable " + uuid + " failed")
logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
@@ -136,14 +138,14 @@ extends Broadcast[T] with Logging with Serializable {
serveMR = null
hostAddress = Utils.localIpAddress
hostAddress = Utils.localIpAddress()
listenPort = -1
stopBroadcast = false
}
def receiveBroadcast(variableUUID: UUID): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableUUID)
def receiveBroadcast(variableID: Long): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
@@ -316,7 +318,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
MultiTracker.unregisterBroadcast(uuid)
MultiTracker.unregisterBroadcast(id)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
@@ -572,7 +574,10 @@ extends Broadcast[T] with Logging with Serializable {
class TreeBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new TreeBroadcast[T](value_, isLocal)
def stop() = MultiTracker.stop
def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id)
def stop() { MultiTracker.stop() }
}

core/src/main/scala/spark/storage/BlockManager.scala

@@ -82,8 +82,6 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
val compress = System.getProperty("spark.blockManager.compress", "false").toBoolean
initLogging()
initialize()
/**

core/src/main/scala/spark/storage/MemoryStore.scala

@@ -147,6 +147,8 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
logInfo("MemoryStore cleared")
}
// TODO: This should be able to return false if the space is larger than our total memory,
// or if adding this block would require evicting another one from the same RDD
private def ensureFreeSpace(space: Long) {
logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
space, currentMemory, maxMemory))
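The new TODO suggests ensureFreeSpace should be able to refuse a request outright. A purely speculative sketch of that direction, not something this commit implements:

    // Speculative sketch of the TODO above -- not part of this commit.
    private def ensureFreeSpace(space: Long): Boolean = {
      if (space > maxMemory) {
        return false  // the block could never fit, even after evicting everything
      }
      // ... otherwise evict blocks (ideally not from the same RDD) until it fits ...
      true
    }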

docs/configuration.md

@@ -73,7 +73,8 @@ there are at least four properties that you will commonly want to control:
<td>/tmp</td>
<td>
Directory to use for "scratch" space in Spark, including map output files and RDDs that get stored
on disk. This should be on a fast, local disk in your system.
on disk. This should be on a fast, local disk in your system. It can also be a comma-separated
list of multiple directories.
</td>
</tr>
<tr>
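The added sentence means spark.local.dir can spread map output files and on-disk RDD data across several disks. A hedged example of setting it in this era of Spark, as a Java system property before the SparkContext is created (the directory names are placeholders):

    // Illustrative only; point these at fast local disks.
    System.setProperty("spark.local.dir", "/mnt/disk1/spark,/mnt/disk2/spark")
    val sc = new SparkContext("local[4]", "Example")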