diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
new file mode 100644
index 0000000000..ad1d29a79a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+
+import scala.math
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.util.Utils
+
+
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
+
+  def value = value_
+
+  def broadcastId = BlockManager.toBroadcastId(id)
+
+  TorrentBroadcast.synchronized {
+    SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+  }
+
+  @transient var arrayOfBlocks: Array[TorrentBlock] = null
+  @transient var totalBlocks = -1
+  @transient var totalBytes = -1
+  @transient var hasBlocks = 0
+
+  if (!isLocal) {
+    sendBroadcast()
+  }
+
+  def sendBroadcast() {
+    var tInfo = TorrentBroadcast.blockifyObject(value_)
+
+    totalBlocks = tInfo.totalBlocks
+    totalBytes = tInfo.totalBytes
+    hasBlocks = tInfo.totalBlocks
+
+    // Store meta-info
+    val metaId = broadcastId + "_meta"
+    val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
+    TorrentBroadcast.synchronized {
+      SparkEnv.get.blockManager.putSingle(
+        metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
+    }
+
+    // Store individual pieces
+    for (i <- 0 until totalBlocks) {
+      val pieceId = broadcastId + "_piece_" + i
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.putSingle(
+          pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
+      }
+    }
+  }
+
+  // Called by JVM when deserializing an object
+  private def readObject(in: ObjectInputStream) {
+    in.defaultReadObject()
+    TorrentBroadcast.synchronized {
+      SparkEnv.get.blockManager.getSingle(broadcastId) match {
+        case Some(x) =>
+          value_ = x.asInstanceOf[T]
+
+        case None =>
+          val start = System.nanoTime
+          logInfo("Started reading broadcast variable " + id)
+
+          // Master might send invalid values
+          resetWorkerVariables()
+
+          if (receiveBroadcast(id)) {
+            value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+            SparkEnv.get.blockManager.putSingle(
+              broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+
+            // Remove arrayOfBlocks from memory once value_ is on local cache
+            resetWorkerVariables()
+          } else {
+            logError("Reading broadcast variable " + id + " failed")
+          }
+
+          val time = (System.nanoTime - start) / 1e9
+          logInfo("Reading broadcast variable " + id + " took " + time + " s")
+      }
+    }
+  }
+
+  private def resetWorkerVariables() {
+    arrayOfBlocks = null
+    totalBytes = -1
+    totalBlocks = -1
+    hasBlocks = 0
+  }
+
+  def receiveBroadcast(variableID: Long): Boolean = {
+    if (totalBlocks > 0 && totalBlocks == hasBlocks)
+      return true
+
+    // Receive meta-info
+    val metaId = broadcastId + "_meta"
+    var attemptId = 10
+    while (attemptId > 0 && totalBlocks == -1) {
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.getSingle(metaId) match {
+          case Some(x) =>
+            val tInfo = x.asInstanceOf[TorrentInfo]
+            totalBlocks = tInfo.totalBlocks
+            totalBytes = tInfo.totalBytes
+            arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
+            hasBlocks = 0
+
+          case None =>
+            Thread.sleep(500)
+        }
+      }
+      attemptId -= 1
+    }
+    if (totalBlocks == -1)
+      return false
+
+    // Receive actual blocks
+    val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
+    for (pid <- recvOrder) {
+      val pieceId = broadcastId + "_piece_" + pid
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.getSingle(pieceId) match {
+          case Some(x) =>
+            arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
+            hasBlocks += 1
+            SparkEnv.get.blockManager.putSingle(
+              pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
+
+          case None =>
+            throw new SparkException(
+              "Failed to get " + pieceId + " of " + broadcastId)
+        }
+      }
+    }
+
+    (hasBlocks == totalBlocks)
+  }
+
+}
+
+private object TorrentBroadcast
+extends Logging {
+
+  private var initialized = false
+
+  def initialize(_isDriver: Boolean) {
+    synchronized {
+      if (!initialized) {
+        initialized = true
+      }
+    }
+  }
+
+  def stop() {
+    initialized = false
+  }
+
+  val BlockSize = System.getProperty("spark.broadcast.blockSize", "2048").toInt * 1024
+
+  def blockifyObject[IN](obj: IN): TorrentInfo = {
+    val byteArray = Utils.serialize[IN](obj)
+    val bais = new ByteArrayInputStream(byteArray)
+
+    var blockNum = (byteArray.length / BlockSize)
+    if (byteArray.length % BlockSize != 0)
+      blockNum += 1
+
+    var retVal = new Array[TorrentBlock](blockNum)
+    var blockID = 0
+
+    for (i <- 0 until (byteArray.length, BlockSize)) {
+      val thisBlockSize = math.min(BlockSize, byteArray.length - i)
+      var tempByteArray = new Array[Byte](thisBlockSize)
+      val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+
+      retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
+      blockID += 1
+    }
+    bais.close()
+
+    var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
+    tInfo.hasBlocks = blockNum
+
+    return tInfo
+  }
+
+  def unBlockifyObject[OUT](arrayOfBlocks: Array[TorrentBlock],
+                            totalBytes: Int,
+                            totalBlocks: Int): OUT = {
+    var retByteArray = new Array[Byte](totalBytes)
+    for (i <- 0 until totalBlocks) {
+      System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
+        i * BlockSize, arrayOfBlocks(i).byteArray.length)
+    }
+    Utils.deserialize[OUT](retByteArray, Thread.currentThread.getContextClassLoader)
+  }
+
+}
+
+private[spark] case class TorrentBlock(
+    blockID: Int,
+    byteArray: Array[Byte])
+  extends Serializable
+
+private[spark] case class TorrentInfo(
+    @transient arrayOfBlocks: Array[TorrentBlock],
+    totalBlocks: Int,
+    totalBytes: Int)
+  extends Serializable {
+
+  @transient var hasBlocks = 0
+}
+
+private[spark] class TorrentBroadcastFactory
+  extends BroadcastFactory {
+
+  def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
+
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+    new TorrentBroadcast[T](value_, isLocal, id)
+
+  def stop() { TorrentBroadcast.stop() }
+}
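
The core mechanism above may be easier to see outside Spark: blockifyObject serializes the value and slices the bytes into fixed-size blocks, and unBlockifyObject copies each block back to its offset. Below is a minimal, self-contained Scala sketch of that round trip; SketchBlock, blockify, unblockify and the 4 KB block size are illustrative stand-ins, not Spark API.

    object BlockifySketch {
      val BlockSize = 4 * 1024 // the patch reads spark.broadcast.blockSize (default 2048 KB)

      case class SketchBlock(id: Int, bytes: Array[Byte])

      def blockify(payload: Array[Byte]): Array[SketchBlock] = {
        // Round up so a trailing partial block is kept, as blockifyObject does
        val numBlocks = (payload.length + BlockSize - 1) / BlockSize
        Array.tabulate(numBlocks) { i =>
          val start = i * BlockSize
          SketchBlock(i, payload.slice(start, math.min(start + BlockSize, payload.length)))
        }
      }

      def unblockify(blocks: Array[SketchBlock], totalBytes: Int): Array[Byte] = {
        val out = new Array[Byte](totalBytes)
        // Each block's id fixes its offset, so arrival order does not matter
        for (b <- blocks) System.arraycopy(b.bytes, 0, out, b.id * BlockSize, b.bytes.length)
        out
      }

      def main(args: Array[String]) {
        val payload = Array.tabulate[Byte](10000)(i => (i % 127).toByte)
        assert(unblockify(blockify(payload), payload.length).sameElements(payload))
        println("round trip ok")
      }
    }

This mirrors why receiveBroadcast can fetch pieces in a random order: reassembly depends only on each block's id, never on arrival order.
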
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 801f88a3db..c67a61515e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -21,6 +21,7 @@ import java.io.{InputStream, OutputStream}
 import java.nio.{ByteBuffer, MappedByteBuffer}
 
 import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
+import scala.util.Random
 
 import akka.actor.{ActorSystem, Cancellable, Props}
 import akka.dispatch.{Await, Future}
@@ -269,7 +270,7 @@ private[spark] class BlockManager(
   }
 
   /**
-   * Actually send a UpdateBlockInfo message. Returns the mater's response,
+   * Actually send a UpdateBlockInfo message. Returns the master's response,
    * which will be true if the block was successfully recorded and false if
    * the slave needs to re-register.
    */
@@ -478,7 +479,7 @@ private[spark] class BlockManager(
     }
     logDebug("Getting remote block " + blockId)
     // Get locations of block
-    val locations = master.getLocations(blockId)
+    val locations = Random.shuffle(master.getLocations(blockId))
 
     // Get block from remote locations
     for (loc <- locations) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 633230c0a8..8b2a812d20 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -227,9 +227,10 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
   }
 
   private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
-    if (id.executorId == "" && !isLocal) {
+    /* if (id.executorId == "" && !isLocal) {
       // Got a register message from the master node; don't register it
-    } else if (!blockManagerInfo.contains(id)) {
+    } else */
+    if (!blockManagerInfo.contains(id)) {
       blockManagerIdByExecutor.get(id.executorId) match {
         case Some(manager) =>
           // A block manager of the same executor already exists.
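
The one-line Random.shuffle change in BlockManager is the load-spreading half of this patch: without it, every reader fetches a remote block from the first location the master returns, so a single replica absorbs all the traffic. A rough standalone illustration, using hypothetical block-manager names in place of real BlockManagerIds:

    import scala.util.Random

    object ShuffleLocationsSketch {
      def main(args: Array[String]) {
        val locations = List("bm-0", "bm-1", "bm-2") // hypothetical replica holders

        // 900 simulated readers, each fetching from the head of its location list
        val without = List.fill(900)(locations.head)
        val withShuffle = List.fill(900)(Random.shuffle(locations).head)

        for (l <- locations)
          println(l + ": unshuffled=" + without.count(_ == l) +
            " shuffled=" + withShuffle.count(_ == l))
      }
    }

Unshuffled, bm-0 serves all 900 fetches; shuffled, each replica serves roughly 300, which is what lets the torrent-style broadcast scale as executors re-serve the pieces they have fetched.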