Torrent-ish broadcast based on BlockManager.
This commit is contained in:
parent
f9973cae3a
commit
4602e2bf6e
|
@ -0,0 +1,245 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.broadcast
|
||||
|
||||
import java.io._
|
||||
|
||||
import scala.math
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.storage.{BlockManager, StorageLevel}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
||||
/**
 * A [[Broadcast]] whose value is cut into pieces that are stored in, and fetched through,
 * the BlockManager ("Torrent-ish" broadcast).
 *
 * On construction the value is stored whole in the local block manager and, unless `isLocal`
 * is set, it is also serialized, split into [[TorrentBlock]]s and stored piece by piece next
 * to a meta-info block. On deserialization (see `readObject`) the value is first looked up in
 * the local block manager; if it is missing, the meta-info and every piece are fetched and
 * the value is re-assembled from them.
 *
 * Every block manager access in this class runs while holding the monitor of the
 * `TorrentBroadcast` companion object.
 *
 * @param value_  the broadcast value; transient, so it is rebuilt in `readObject`
 * @param isLocal when true, the value is not split into pieces on construction
 * @param id      broadcast id from which all block ids used here are derived
 */
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

  /** The broadcast value. */
  def value: T = value_

  /** Block id under which the whole (unsplit) value is stored. */
  def broadcastId = BlockManager.toBroadcastId(id)

  // Keep the whole value available locally.
  // NOTE(review): the trailing `false` is presumably putSingle's `tellMaster` flag -- confirm
  // against BlockManager.putSingle.
  TorrentBroadcast.synchronized {
    SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
  }

  // Receive-side bookkeeping. These declarations must stay ABOVE the sendBroadcast() call
  // below: initializers run in source order, and sendBroadcast() assigns these fields.
  @transient var arrayOfBlocks: Array[TorrentBlock] = null
  @transient var totalBlocks = -1
  @transient var totalBytes = -1
  @transient var hasBlocks = 0

  if (!isLocal) {
    sendBroadcast()
  }

  /**
   * Serializes and splits `value_`, then stores the meta-info block (`<broadcastId>_meta`)
   * and one block per piece (`<broadcastId>_piece_<i>`) in the block manager.
   */
  def sendBroadcast(): Unit = {
    val tInfo = TorrentBroadcast.blockifyObject(value_)

    totalBlocks = tInfo.totalBlocks
    totalBytes = tInfo.totalBytes
    hasBlocks = tInfo.totalBlocks

    // Store meta-info (the pieces themselves are deliberately left out of it)
    val metaId = broadcastId + "_meta"
    val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
    TorrentBroadcast.synchronized {
      SparkEnv.get.blockManager.putSingle(
        metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
    }

    // Store individual pieces
    for (i <- 0 until totalBlocks) {
      val pieceId = broadcastId + "_piece_" + i
      TorrentBroadcast.synchronized {
        SparkEnv.get.blockManager.putSingle(
          pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
      }
    }
  }

  // Called by JVM when deserializing an object
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    TorrentBroadcast.synchronized {
      SparkEnv.get.blockManager.getSingle(broadcastId) match {
        case Some(x) =>
          // Value is already stored locally; nothing to fetch.
          value_ = x.asInstanceOf[T]

        case None =>
          val start = System.nanoTime
          logInfo("Started reading broadcast variable " + id)

          // Master might send invalid values. Also, @transient fields are not restored by
          // defaultReadObject(), so the "nothing received yet" sentinels must be re-set.
          resetWorkerVariables()

          if (receiveBroadcast(id)) {
            value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
            SparkEnv.get.blockManager.putSingle(
              broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)

            // Remove arrayOfBlocks from memory once value_ is on local cache
            resetWorkerVariables()
          } else {
            // Best effort: the failure is only logged and value_ stays unset.
            logError("Reading broadcast variable " + id + " failed")
          }

          val time = (System.nanoTime - start) / 1e9
          logInfo("Reading broadcast variable " + id + " took " + time + " s")
      }
    }
  }

  /** Puts the receive-side fields back into their "nothing received" state. */
  private def resetWorkerVariables(): Unit = {
    arrayOfBlocks = null
    totalBytes = -1
    totalBlocks = -1
    hasBlocks = 0
  }

  /**
   * Fetches the meta-info and all pieces of this broadcast into `arrayOfBlocks`.
   *
   * @param variableID not used by the implementation; kept for interface compatibility
   * @return true when all pieces are present, false when the meta-info could not be obtained
   * @throws SparkException if one of the pieces cannot be fetched
   */
  def receiveBroadcast(variableID: Long): Boolean = {
    if (totalBlocks > 0 && totalBlocks == hasBlocks) {
      // Everything was received earlier.
      true
    } else if (!receiveMetaInfo()) {
      false
    } else {
      receivePieces()
      hasBlocks == totalBlocks
    }
  }

  /**
   * Looks up the meta-info block, making up to 10 attempts with a 500 ms pause between them,
   * and on success initializes `totalBlocks`, `totalBytes`, `arrayOfBlocks` and `hasBlocks`.
   *
   * @return true when the meta-info is known (`totalBlocks != -1`) afterwards
   */
  private def receiveMetaInfo(): Boolean = {
    val metaId = broadcastId + "_meta"
    var attemptsLeft = 10
    while (attemptsLeft > 0 && totalBlocks == -1) {
      attemptsLeft -= 1
      TorrentBroadcast.synchronized {
        SparkEnv.get.blockManager.getSingle(metaId) match {
          case Some(x) =>
            val tInfo = x.asInstanceOf[TorrentInfo]
            totalBlocks = tInfo.totalBlocks
            totalBytes = tInfo.totalBytes
            arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
            hasBlocks = 0

          case None =>
            // Only wait when another attempt will follow; sleeping after the last attempt
            // would just delay the failure while holding the lock.
            if (attemptsLeft > 0) {
              Thread.sleep(500)
            }
        }
      }
    }
    totalBlocks != -1
  }

  /**
   * Fetches every piece in random order, records it in `arrayOfBlocks` and stores it in the
   * local block manager again.
   *
   * @throws SparkException if a piece cannot be fetched
   */
  private def receivePieces(): Unit = {
    val recvOrder = new Random().shuffle((0 until totalBlocks).toList)
    for (pid <- recvOrder) {
      val pieceId = broadcastId + "_piece_" + pid
      TorrentBroadcast.synchronized {
        SparkEnv.get.blockManager.getSingle(pieceId) match {
          case Some(x) =>
            arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
            hasBlocks += 1
            // NOTE(review): presumably re-stored with tellMaster = true so that this node
            // becomes another source for the piece -- confirm against BlockManager.putSingle.
            SparkEnv.get.blockManager.putSingle(
              pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)

          case None =>
            throw new SparkException(
              "Failed to get " + pieceId + " of " + broadcastId)
        }
      }
    }
  }

}
|
||||
|
||||
/**
 * Companion of the `TorrentBroadcast` class: holds the piece size, the helpers that split a
 * value into [[TorrentBlock]]s and join them again, and the monitor that `TorrentBroadcast`
 * instances synchronize on.
 */
private object TorrentBroadcast
  extends Logging {

  private var initialized = false

  /**
   * Marks the object as initialized; repeated calls have no further effect.
   *
   * @param _isDriver not used by the implementation
   */
  def initialize(_isDriver: Boolean): Unit = {
    synchronized {
      if (!initialized) {
        initialized = true
      }
    }
  }

  /** Clears the initialized flag. */
  def stop(): Unit = {
    // Intentionally not synchronized: instances hold this object's monitor for the whole
    // duration of a fetch, and stop() must not wait for that.
    initialized = false
  }

  /**
   * Piece size in bytes: the `spark.broadcast.blockSize` system property (default "2048")
   * multiplied by 1024.
   */
  val BlockSize = System.getProperty("spark.broadcast.blockSize", "2048").toInt * 1024

  /**
   * Serializes `obj` and cuts the bytes into consecutive pieces of at most `BlockSize` bytes.
   *
   * @param obj the value to split
   * @return a [[TorrentInfo]] carrying the pieces, their count and the total byte count,
   *         with `hasBlocks` set to the piece count
   */
  def blockifyObject[IN](obj: IN): TorrentInfo = {
    val byteArray = Utils.serialize[IN](obj)

    // grouped() yields freshly copied chunks; only the last one may be shorter than BlockSize.
    val blocks = byteArray.grouped(BlockSize).zipWithIndex.map {
      case (bytes, blockID) => new TorrentBlock(blockID, bytes)
    }.toArray

    val tInfo = TorrentInfo(blocks, blocks.length, byteArray.length)
    tInfo.hasBlocks = blocks.length
    tInfo
  }

  /**
   * Concatenates the first `totalBlocks` pieces and deserializes the result with the current
   * thread's context class loader.
   *
   * Each piece is placed right after the previous one, using the pieces' actual lengths, so
   * re-assembly does not depend on this JVM's `BlockSize` matching the one used for splitting.
   *
   * @param arrayOfBlocks pieces in order of their position in the serialized value
   * @param totalBytes    length of the serialized value
   * @param totalBlocks   number of pieces to use from `arrayOfBlocks`
   * @return the deserialized value
   */
  def unBlockifyObject[OUT](arrayOfBlocks: Array[TorrentBlock],
                            totalBytes: Int,
                            totalBlocks: Int): OUT = {
    val retByteArray = new Array[Byte](totalBytes)
    var offset = 0
    for (i <- 0 until totalBlocks) {
      val bytes = arrayOfBlocks(i).byteArray
      System.arraycopy(bytes, 0, retByteArray, offset, bytes.length)
      offset += bytes.length
    }
    Utils.deserialize[OUT](retByteArray, Thread.currentThread.getContextClassLoader)
  }

}
|
||||
|
||||
/**
 * One piece of a serialized broadcast value.
 *
 * @param blockID   index of this piece
 * @param byteArray the bytes of this piece
 */
private[spark] case class TorrentBlock(blockID: Int, byteArray: Array[Byte])
  extends Serializable
|
||||
|
||||
/**
 * Meta-information about a broadcast value that has been split into pieces.
 *
 * @param arrayOfBlocks the pieces themselves; transient, so not part of the serialized form
 * @param totalBlocks   number of pieces
 * @param totalBytes    length in bytes of the serialized value
 */
private[spark] case class TorrentInfo(
    @transient arrayOfBlocks : Array[TorrentBlock],
    totalBlocks: Int,
    totalBytes: Int)
  extends Serializable {

  /** Number of pieces currently held; transient, so not part of the serialized form. */
  @transient var hasBlocks = 0

}
|
||||
|
||||
/**
 * [[BroadcastFactory]] that creates `TorrentBroadcast` instances and forwards the lifecycle
 * calls to the `TorrentBroadcast` object.
 */
private[spark] class TorrentBroadcastFactory
  extends BroadcastFactory {

  /** Forwards to `TorrentBroadcast.initialize`. */
  def initialize(isDriver: Boolean): Unit = TorrentBroadcast.initialize(isDriver)

  /** Creates a new `TorrentBroadcast` from the given value, locality flag and id. */
  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
    new TorrentBroadcast[T](value_, isLocal, id)

  /** Forwards to `TorrentBroadcast.stop`. */
  def stop(): Unit = TorrentBroadcast.stop()

}
|
|
@ -21,6 +21,7 @@ import java.io.{InputStream, OutputStream}
|
|||
import java.nio.{ByteBuffer, MappedByteBuffer}
|
||||
|
||||
import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
|
||||
import scala.util.Random
|
||||
|
||||
import akka.actor.{ActorSystem, Cancellable, Props}
|
||||
import akka.dispatch.{Await, Future}
|
||||
|
@ -269,7 +270,7 @@ private[spark] class BlockManager(
|
|||
}
|
||||
|
||||
/**
|
||||
* Actually send a UpdateBlockInfo message. Returns the mater's response,
|
||||
* Actually send a UpdateBlockInfo message. Returns the master's response,
|
||||
* which will be true if the block was successfully recorded and false if
|
||||
* the slave needs to re-register.
|
||||
*/
|
||||
|
@ -478,7 +479,7 @@ private[spark] class BlockManager(
|
|||
}
|
||||
logDebug("Getting remote block " + blockId)
|
||||
// Get locations of block
|
||||
val locations = master.getLocations(blockId)
|
||||
val locations = Random.shuffle(master.getLocations(blockId))
|
||||
|
||||
// Get block from remote locations
|
||||
for (loc <- locations) {
|
||||
|
|
|
@ -227,9 +227,10 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
|
|||
}
|
||||
|
||||
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
|
||||
if (id.executorId == "<driver>" && !isLocal) {
|
||||
/* if (id.executorId == "<driver>" && !isLocal) {
|
||||
// Got a register message from the master node; don't register it
|
||||
} else if (!blockManagerInfo.contains(id)) {
|
||||
} else */
|
||||
if (!blockManagerInfo.contains(id)) {
|
||||
blockManagerIdByExecutor.get(id.executorId) match {
|
||||
case Some(manager) =>
|
||||
// A block manager of the same executor already exists.
|
||||
|
|
Loading…
Reference in a new issue