Torrent-ish broadcast based on BlockManager.

This commit is contained in:
Mosharaf Chowdhury 2013-10-13 18:46:03 -07:00
parent f9973cae3a
commit 4602e2bf6e
3 changed files with 251 additions and 4 deletions

View file

@ -0,0 +1,245 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.io._
import scala.math
import scala.util.Random
import org.apache.spark._
import org.apache.spark.storage.{BlockManager, StorageLevel}
import org.apache.spark.util.Utils
/**
 * A Broadcast implementation that distributes its value through the BlockManager, both as a
 * whole object and as a set of fixed-size pieces (TorrentBlock instances).
 *
 * On construction the value is stored through the block manager under `broadcastId`. Unless
 * `isLocal` is set, it is additionally serialized, split into pieces, and stored as one
 * meta-info block (broadcastId + "_meta") plus one block per piece
 * (broadcastId + "_piece_" + index).
 *
 * On deserialization (see readObject) the value is looked up with blockManager.getSingle and,
 * if that finds nothing, rebuilt from the meta-info block and the piece blocks.
 *
 * Every block manager call in this class is wrapped in `TorrentBroadcast.synchronized`, i.e.
 * it locks on the companion object.
 *
 * @param value_ the value being broadcast; marked transient, so it is not serialized with
 *               this object
 * @param isLocal when true, the constructor does not split the value into pieces
 * @param id id of this broadcast variable; all block ids are derived from it
 */
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
/** The broadcast value as currently held by this instance. */
def value = value_
/** Block id under which the whole (un-split) value is stored. */
def broadcastId = BlockManager.toBroadcastId(id)
// Constructor body: store the whole value first.
// NOTE(review): the trailing boolean is presumably putSingle's "tell master" flag (false for
// the whole value, true for meta-info and pieces) -- confirm against BlockManager.putSingle.
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
// Transient bookkeeping for the piece-wise transfer. totalBlocks == -1 is the
// "meta-info not fetched yet" sentinel that receiveBroadcast() loops on.
@transient var arrayOfBlocks: Array[TorrentBlock] = null
@transient var totalBlocks = -1
@transient var totalBytes = -1
@transient var hasBlocks = 0
// Must stay below the field initializers above: sendBroadcast() assigns totalBlocks,
// totalBytes and hasBlocks, and initializers running afterwards would overwrite them.
if (!isLocal) {
sendBroadcast()
}
/**
 * Serializes `value_`, splits it into pieces and stores the meta-info block followed by one
 * block per piece. Also records the piece count and byte count in this instance.
 */
def sendBroadcast() {
var tInfo = TorrentBroadcast.blockifyObject(value_)
totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
hasBlocks = tInfo.totalBlocks
// Store meta-info
// The meta-info block carries only the counts; its arrayOfBlocks is deliberately null.
val metaId = broadcastId + "_meta"
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
}
// Store individual pieces
for (i <- 0 until totalBlocks) {
val pieceId = broadcastId + "_piece_" + i
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
}
}
}
// Called by JVM when deserializing an object
/**
 * Restores `value_` after default deserialization: uses the block stored under `broadcastId`
 * if getSingle finds one, otherwise fetches the pieces via receiveBroadcast(), reassembles
 * the value and stores it under `broadcastId`. The whole method runs while holding the
 * companion object's lock.
 */
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(broadcastId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
val start = System.nanoTime
logInfo("Started reading broadcast variable " + id)
// Master might send invalid values
// Put the bookkeeping fields back to the sentinels receiveBroadcast() expects.
resetWorkerVariables()
if (receiveBroadcast(id)) {
value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
// Remove arrayOfBlocks from memory once value_ is on local cache
resetWorkerVariables()
} else {
// NOTE(review): the failure is only logged; value_ is not assigned on this path, so
// callers of `value` will see whatever deserialization left there -- confirm intended.
logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
/** Resets the transfer bookkeeping to its initial state and drops the piece array. */
private def resetWorkerVariables() {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
}
/**
 * Fetches the meta-info block and then every piece block into `arrayOfBlocks`.
 *
 * The meta-info lookup is attempted up to 10 times, sleeping 500 ms after each miss. Pieces
 * are requested in a random order, and each fetched piece is stored again under its own
 * piece id.
 *
 * @param variableID not referenced in the method body; block ids come from `broadcastId`
 * @return true if all pieces are present; false if the meta-info block was never found
 * @throws SparkException if getSingle returns None for any piece
 */
def receiveBroadcast(variableID: Long): Boolean = {
// Nothing to do if every piece is already held.
if (totalBlocks > 0 && totalBlocks == hasBlocks)
return true
// Receive meta-info
val metaId = broadcastId + "_meta"
var attemptId = 10
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(metaId) match {
case Some(x) =>
val tInfo = x.asInstanceOf[TorrentInfo]
totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
hasBlocks = 0
case None =>
// NOTE(review): this sleep runs while the companion object's lock is held (and also
// after the final failed attempt) -- confirm that blocking other users is acceptable.
Thread.sleep(500)
}
}
attemptId -= 1
}
// Meta-info never showed up within the retry budget.
if (totalBlocks == -1)
return false
// Receive actual blocks
// Shuffle the piece indices 0 until totalBlocks so pieces are requested in random order.
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
for (pid <- recvOrder) {
val pieceId = broadcastId + "_piece_" + pid
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
hasBlocks += 1
// Store the fetched piece again under the same id.
// NOTE(review): presumably this makes the piece available from this node to other
// fetchers -- confirm against BlockManager.putSingle semantics.
SparkEnv.get.blockManager.putSingle(
pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
case None =>
throw new SparkException(
"Failed to get " + pieceId + " of " + broadcastId)
}
}
}
(hasBlocks == totalBlocks)
}
}
/**
 * Companion of the TorrentBroadcast class. It serves as the lock object for that class
 * (`TorrentBroadcast.synchronized`) and provides the helpers that split a serialized object
 * into fixed-size pieces and reassemble it.
 */
private object TorrentBroadcast
  extends Logging {

  private var initialized = false

  /**
   * Size of one piece in bytes: the integer value of the `spark.broadcast.blockSize` system
   * property (default "2048") multiplied by 1024.
   */
  val BlockSize = System.getProperty("spark.broadcast.blockSize", "2048").toInt * 1024

  /**
   * Marks the broadcast machinery as initialized; repeated calls are no-ops.
   *
   * @param _isDriver currently unused
   */
  def initialize(_isDriver: Boolean): Unit = synchronized {
    if (!initialized) {
      initialized = true
    }
  }

  /** Clears the initialized flag so that initialize() can run again. */
  def stop(): Unit = {
    initialized = false
  }

  /**
   * Serializes `obj` and cuts the resulting bytes into consecutive pieces of at most
   * `BlockSize` bytes each; only the last piece may be shorter.
   *
   * @param obj the object to serialize and split
   * @return a TorrentInfo holding the pieces, their count and the total byte count, with
   *         `hasBlocks` set to the piece count
   * @throws IllegalArgumentException if `BlockSize` is not positive
   */
  def blockifyObject[IN](obj: IN): TorrentInfo = {
    // A zero or negative block size would otherwise surface as an ArithmeticException or a
    // NegativeArraySizeException further down; fail with a message that names the setting.
    require(BlockSize > 0,
      "spark.broadcast.blockSize must be positive, but BlockSize is " + BlockSize)

    val byteArray = Utils.serialize[IN](obj)

    // Ceiling division, written without `length + BlockSize - 1` to avoid Int overflow.
    val fullBlocks = byteArray.length / BlockSize
    val blockNum = if (byteArray.length % BlockSize != 0) fullBlocks + 1 else fullBlocks

    val blocks = Array.tabulate(blockNum) { blockID =>
      val from = blockID * BlockSize
      val thisBlockSize = math.min(BlockSize, byteArray.length - from)
      // slice copies the range into a fresh array, so pieces do not alias byteArray.
      new TorrentBlock(blockID, byteArray.slice(from, from + thisBlockSize))
    }

    val tInfo = TorrentInfo(blocks, blockNum, byteArray.length)
    tInfo.hasBlocks = blockNum
    tInfo
  }

  /**
   * Reassembles the pieces into one byte array and deserializes it with the current thread's
   * context class loader.
   *
   * Piece `i` is copied to offset `i * BlockSize`, so the pieces must have been produced with
   * the same `BlockSize` that is in effect here.
   *
   * @param arrayOfBlocks the pieces, indexed by piece number
   * @param totalBytes length of the original serialized form
   * @param totalBlocks number of pieces to copy from `arrayOfBlocks`
   * @return the deserialized object
   */
  def unBlockifyObject[OUT](arrayOfBlocks: Array[TorrentBlock],
      totalBytes: Int,
      totalBlocks: Int): OUT = {
    val retByteArray = new Array[Byte](totalBytes)
    for (i <- 0 until totalBlocks) {
      val piece = arrayOfBlocks(i).byteArray
      System.arraycopy(piece, 0, retByteArray, i * BlockSize, piece.length)
    }
    Utils.deserialize[OUT](retByteArray, Thread.currentThread.getContextClassLoader)
  }
}
/**
 * One piece of a serialized broadcast value.
 *
 * @param blockID index of this piece within the broadcast
 * @param byteArray the bytes of this piece
 */
private[spark] case class TorrentBlock(blockID: Int, byteArray: Array[Byte])
  extends Serializable
/**
 * Describes a broadcast value that has been split into pieces.
 *
 * @param arrayOfBlocks the pieces themselves; transient, so not part of the serialized form
 * @param totalBlocks number of pieces
 * @param totalBytes length in bytes of the serialized value
 */
private[spark] case class TorrentInfo(
    @transient arrayOfBlocks : Array[TorrentBlock],
    totalBlocks: Int,
    totalBytes: Int)
  extends Serializable {

  /** Number of pieces currently held; transient, so not part of the serialized form. */
  @transient var hasBlocks = 0
}
/**
 * BroadcastFactory that creates TorrentBroadcast instances and forwards lifecycle calls to
 * the TorrentBroadcast companion object.
 */
private[spark] class TorrentBroadcastFactory extends BroadcastFactory {

  /** Forwards to TorrentBroadcast.initialize. */
  def initialize(isDriver: Boolean) {
    TorrentBroadcast.initialize(isDriver)
  }

  /** Creates a new TorrentBroadcast wrapping `value_`. */
  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = {
    new TorrentBroadcast[T](value_, isLocal, id)
  }

  /** Forwards to TorrentBroadcast.stop. */
  def stop() {
    TorrentBroadcast.stop()
  }
}

View file

@ -21,6 +21,7 @@ import java.io.{InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
import akka.dispatch.{Await, Future}
@ -269,7 +270,7 @@ private[spark] class BlockManager(
}
/**
* Actually send a UpdateBlockInfo message. Returns the mater's response,
* Actually send a UpdateBlockInfo message. Returns the master's response,
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
@ -478,7 +479,7 @@ private[spark] class BlockManager(
}
logDebug("Getting remote block " + blockId)
// Get locations of block
val locations = master.getLocations(blockId)
val locations = Random.shuffle(master.getLocations(blockId))
// Get block from remote locations
for (loc <- locations) {

View file

@ -227,9 +227,10 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
if (id.executorId == "<driver>" && !isLocal) {
/* if (id.executorId == "<driver>" && !isLocal) {
// Got a register message from the master node; don't register it
} else if (!blockManagerInfo.contains(id)) {
} else */
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
// A block manager of the same executor already exists.