[SPARK-1103] Automatic garbage collection of RDD, shuffle and broadcast data

This PR allows Spark to automatically clean up metadata and data related to persisted RDDs, shuffles, and broadcast variables when the corresponding objects fall out of scope in the driver program. This is still a work in progress, as broadcast cleanup has not been implemented yet.

**Implementation Details**
A new class, `ContextCleaner`, is responsible for cleaning up all of this state. It is instantiated as part of a `SparkContext`. The RDD and ShuffleDependency classes override the `finalize()` method, which is called when their instances are garbage collected. The `finalize()` method enqueues the object's identifier (i.e. RDD ID, shuffle ID, etc.) with the `ContextCleaner`; this is a very short and cheap operation and should not significantly affect the garbage collection mechanism. The `ContextCleaner`, on a separate thread, performs the actual cleanup, whose details are given below.
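
The paragraph above describes the `finalize()`-based version; a later commit in this PR (f2881fd) switches the cleaner to a `ReferenceQueue`, but the flow is the same either way: record the object's ID once it becomes unreachable, then perform the cleanup asynchronously on a daemon thread. The following is a minimal, self-contained sketch of that weak-reference/queue pattern. The names (`MiniCleaner`, `MiniCleanRDD`, `TaskRef`) are illustrative only; the real `ContextCleaner` appears in full in the diff below.

```scala
import java.lang.ref.{ReferenceQueue, WeakReference}
import scala.collection.mutable.ArrayBuffer

// Illustrative cleanup task, mirroring the CleanupTask hierarchy added by this PR.
sealed trait MiniTask
case class MiniCleanRDD(rddId: Int) extends MiniTask

// A weak reference that remembers which cleanup task to run once its referent is GC'd.
class TaskRef(val task: MiniTask, referent: AnyRef, queue: ReferenceQueue[AnyRef])
  extends WeakReference[AnyRef](referent, queue)

class MiniCleaner {
  private val queue = new ReferenceQueue[AnyRef]
  private val refs = ArrayBuffer[TaskRef]()  // keep the TaskRefs themselves strongly reachable

  def register(obj: AnyRef, task: MiniTask): Unit =
    synchronized { refs += new TaskRef(task, obj, queue) }

  // Daemon thread: once a registered object becomes only weakly reachable and is collected,
  // its TaskRef shows up on the queue and the cheap "enqueue ID, clean later" hand-off completes.
  private val cleaner = new Thread("mini-cleaner") {
    setDaemon(true)
    override def run(): Unit = while (true) {
      Option(queue.remove(100)).map(_.asInstanceOf[TaskRef]).foreach { ref =>
        MiniCleaner.this.synchronized { refs -= ref }
        ref.task match {
          case MiniCleanRDD(id) => println(s"would call sc.unpersistRDD($id) here")
        }
      }
    }
  }
  cleaner.start()
}
```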

*RDD cleanup:*
`ContextCleaner` calls `RDD.unpersist()` to clean up persisted RDDs. As for metadata, the `DAGScheduler` automatically cleans up all metadata related to an RDD after the jobs using it have completed. Only `SparkContext.persistentRdds` keeps strong references to persisted RDDs; the `TimeStampedHashMap` used for it has been replaced by a `TimeStampedWeakValueHashMap`, which keeps only weak references to the RDDs and therefore allows them to be garbage collected.
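
As an illustration of the resulting behavior (a sketch, assuming `spark.cleaner.referenceTracking` is left at its default of `true` as set by this PR, and that a GC actually runs): once the last strong reference to a cached RDD is dropped, the cleaner eventually unpersists its blocks without any explicit `unpersist()` call.

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cleanup-demo"))

def runJob(): Unit = {
  val rdd = sc.parallelize(1 to 1000).cache()
  rdd.count()    // materializes the cached blocks
}                // the only strong reference to `rdd` is dropped here

runJob()
System.gc()      // after a GC, the cleaner thread receives the weak-reference event and
                 // eventually removes the RDD's cached blocks via sc.unpersistRDD(...)
```

Note that the cleanup is asynchronous and best-effort; there is no guarantee about exactly when the blocks disappear.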

*Shuffle cleanup:*
A new `BlockManager` message, `RemoveShuffle(<shuffle ID>)`, asks the `BlockManagerMaster` and the currently active `BlockManager`s to delete all disk blocks related to the given shuffle ID. `ContextCleaner` cleans up shuffle data using this message and also cleans up the corresponding metadata in the driver's `MapOutputTracker`. The `MapOutputTracker` on the workers, which caches shuffle metadata, maintains a `BoundedHashMap` to limit the amount of shuffle information it caches; re-fetching the shuffle information from the driver is not too costly.
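
To make the sequence concrete, here is a sketch of that cleanup path, mirroring `doCleanupShuffle` in the `ContextCleaner` shown further down in the diff. Note that `SparkContext.env` and the components used here are `private[spark]`, so this would only compile inside Spark's own packages, and `shuffleId` is just an illustrative parameter.

```scala
import org.apache.spark.{MapOutputTrackerMaster, SparkContext}

// Sketch of the shuffle-cleanup path (compare with ContextCleaner.doCleanupShuffle below).
def cleanShuffle(sc: SparkContext, shuffleId: Int, blocking: Boolean): Unit = {
  // 1. Drop the driver-side map output metadata for this shuffle.
  sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster].unregisterShuffle(shuffleId)
  // 2. Ask the BlockManagerMaster to tell the active BlockManagers to delete the
  //    shuffle's disk blocks (the new RemoveShuffle(shuffleId) message).
  sc.env.blockManager.master.removeShuffle(shuffleId, blocking)
}
```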

*Broadcast cleanup:*
To be done. [This PR](https://github.com/apache/incubator-spark/pull/543/) adds a mechanism for the explicit cleanup of broadcast variables. `Broadcast.finalize()` will enqueue its own ID with the `ContextCleaner`, and that PR's mechanism will be used to unpersist the broadcast data.
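
What this PR does already ship is the explicit executor-side cleanup API, `Broadcast.unpersist()` (see the `Broadcast` changes in the diff below): it removes the cached copies on the executors, and the data is re-sent lazily if the broadcast is used again. A small illustrative usage, assuming an existing `SparkContext` named `sc`:

```scala
val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))
sc.parallelize(Seq("a", "b", "a")).map(k => lookup.value(k)).collect()

lookup.unpersist(blocking = false)  // executors drop their cached copies asynchronously
// Using `lookup` again is still valid; its data is re-fetched by the executors on demand.
```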

*Other cleanup:*
`ShuffleMapTask` and `ResultTask` cache task information and previously relied on TTL-based cleanup (via `TimeStampedHashMap`), so nothing was cleaned up if no TTL was set. They now use a `BoundedHashMap` to keep a limited amount of map output information; the cost of repopulating the cache when necessary is very small.
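
`BoundedHashMap` itself does not survive to the final diff (later commits in this PR, e.g. 892b952 and f201a8d, remove it again), but the idea this paragraph relies on is simply a size-bounded map that evicts old entries, which are then re-fetched or recomputed on demand. A minimal sketch of that idea on top of `java.util.LinkedHashMap`, with illustrative names:

```scala
import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap}

// Size-bounded, access-ordered (LRU) map: inserting beyond maxEntries evicts the
// least-recently-used entry, which can then simply be repopulated when needed.
class BoundedMap[K, V](maxEntries: Int)
  extends JLinkedHashMap[K, V](16, 0.75f, /* accessOrder = */ true) {
  override def removeEldestEntry(eldest: JMap.Entry[K, V]): Boolean = size() > maxEntries
}

val cache = new BoundedMap[Int, String](maxEntries = 2)
cache.put(1, "a"); cache.put(2, "b"); cache.put(3, "c")  // entry 1 has been evicted
```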

**Current state of implementation**
Implemented RDD and shuffle cleanup. Remaining work:
- Clean up broadcast variables.
- Automatically clean up keys in `TimeStampedWeakValueHashMap` whose weakly referenced values have been garbage collected.

Author: Tathagata Das <tathagata.das1565@gmail.com>
Author: Andrew Or <andrewor14@gmail.com>
Author: Roman Pastukhov <ignatich@mail.ru>

Closes #126 from tdas/state-cleanup and squashes the following commits:

61b8d6e [Tathagata Das] Fixed issue with Tachyon + new BlockManager methods.
f489fdc [Tathagata Das] Merge remote-tracking branch 'apache/master' into state-cleanup
d25a86e [Tathagata Das] Fixed stupid typo.
cff023c [Tathagata Das] Fixed issues based on Andrew's comments.
4d05314 [Tathagata Das] Scala style fix.
2b95b5e [Tathagata Das] Added more documentation on Broadcast implementations, specially which blocks are told about to the driver. Also, fixed Broadcast API to hide destroy functionality.
41c9ece [Tathagata Das] Added more unit tests for BlockManager, DiskBlockManager, and ContextCleaner.
6222697 [Tathagata Das] Fixed bug and adding unit test for removeBroadcast in BlockManagerSuite.
104a89a [Tathagata Das] Fixed failing BroadcastSuite unit tests by introducing blocking for removeShuffle and removeBroadcast in BlockManager*
a430f06 [Tathagata Das] Fixed compilation errors.
b27f8e8 [Tathagata Das] Merge pull request #3 from andrewor14/cleanup
cd72d19 [Andrew Or] Make automatic cleanup configurable (not documented)
ada45f0 [Andrew Or] Merge branch 'state-cleanup' of github.com:tdas/spark into cleanup
a2cc8bc [Tathagata Das] Merge remote-tracking branch 'apache/master' into state-cleanup
c5b1d98 [Andrew Or] Address Patrick's comments
a6460d4 [Andrew Or] Merge github.com:apache/spark into cleanup
762a4d8 [Tathagata Das] Merge pull request #1 from andrewor14/cleanup
f0aabb1 [Andrew Or] Correct semantics for TimeStampedWeakValueHashMap + add tests
5016375 [Andrew Or] Address TD's comments
7ed72fb [Andrew Or] Fix style test fail + remove verbose test message regarding broadcast
634a097 [Andrew Or] Merge branch 'state-cleanup' of github.com:tdas/spark into cleanup
7edbc98 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into state-cleanup
8557c12 [Andrew Or] Merge github.com:apache/spark into cleanup
e442246 [Andrew Or] Merge github.com:apache/spark into cleanup
88904a3 [Andrew Or] Make TimeStampedWeakValueHashMap a wrapper of TimeStampedHashMap
fbfeec8 [Andrew Or] Add functionality to query executors for their local BlockStatuses
34f436f [Andrew Or] Generalize BroadcastBlockId to remove BroadcastHelperBlockId
0d17060 [Andrew Or] Import, comments, and style fixes (minor)
c92e4d9 [Andrew Or] Merge github.com:apache/spark into cleanup
f201a8d [Andrew Or] Test broadcast cleanup in ContextCleanerSuite + remove BoundedHashMap
e95479c [Andrew Or] Add tests for unpersisting broadcast
544ac86 [Andrew Or] Clean up broadcast blocks through BlockManager*
d0edef3 [Andrew Or] Add framework for broadcast cleanup
ba52e00 [Andrew Or] Refactor broadcast classes
c7ccef1 [Andrew Or] Merge branch 'bc-unpersist-merge' of github.com:ignatich/incubator-spark into cleanup
6c9dcf6 [Tathagata Das] Added missing Apache license
d2f8b97 [Tathagata Das] Removed duplicate unpersistRDD.
a007307 [Tathagata Das] Merge remote-tracking branch 'apache/master' into state-cleanup
620eca3 [Tathagata Das] Changes based on PR comments.
f2881fd [Tathagata Das] Changed ContextCleaner to use ReferenceQueue instead of finalizer
e1fba5f [Tathagata Das] Style fix
892b952 [Tathagata Das] Removed use of BoundedHashMap, and made BlockManagerSlaveActor cleanup shuffle metadata in MapOutputTrackerWorker.
a7260d3 [Tathagata Das] Added try-catch in context cleaner and null value cleaning in TimeStampedWeakValueHashMap.
e61daa0 [Tathagata Das] Modifications based on the comments on PR 126.
ae9da88 [Tathagata Das] Removed unncessary TimeStampedHashMap from DAGScheduler, added try-catches in finalize() methods, and replaced ArrayBlockingQueue to LinkedBlockingQueue to avoid blocking in Java's finalizing thread.
cb0a5a6 [Tathagata Das] Fixed docs and styles.
a24fefc [Tathagata Das] Merge remote-tracking branch 'apache/master' into state-cleanup
8512612 [Tathagata Das] Changed TimeStampedHashMap to use WrappedJavaHashMap.
e427a9e [Tathagata Das] Added ContextCleaner to automatically clean RDDs and shuffles when they fall out of scope. Also replaced TimeStampedHashMap to BoundedHashMaps and TimeStampedWeakValueHashMap for the necessary hashmap behavior.
80dd977 [Roman Pastukhov] Fix for Broadcast unpersist patch.
1e752f1 [Roman Pastukhov] Added unpersist method to Broadcast.
Authored by Tathagata Das on 2014-04-07 23:40:21 -07:00; committed by Patrick Wendell
parent 0d0493fcf7
commit 11eabbe125
40 changed files with 2570 additions and 468 deletions

@ -0,0 +1,192 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.lang.ref.{ReferenceQueue, WeakReference}
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
/**
* Classes that represent cleaning tasks.
*/
private sealed trait CleanupTask
private case class CleanRDD(rddId: Int) extends CleanupTask
private case class CleanShuffle(shuffleId: Int) extends CleanupTask
private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
/**
* A WeakReference associated with a CleanupTask.
*
* When the referent object becomes only weakly reachable, the corresponding
* CleanupTaskWeakReference is automatically added to the given reference queue.
*/
private class CleanupTaskWeakReference(
val task: CleanupTask,
referent: AnyRef,
referenceQueue: ReferenceQueue[AnyRef])
extends WeakReference(referent, referenceQueue)
/**
* An asynchronous cleaner for RDD, shuffle, and broadcast state.
*
* This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest,
* to be processed when the associated object goes out of scope of the application. Actual
* cleanup is performed in a separate daemon thread.
*/
private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
with SynchronizedBuffer[CleanupTaskWeakReference]
private val referenceQueue = new ReferenceQueue[AnyRef]
private val listeners = new ArrayBuffer[CleanerListener]
with SynchronizedBuffer[CleanerListener]
private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
/**
* Whether the cleaning thread will block on cleanup tasks.
* This is set to true only for tests.
*/
private val blockOnCleanupTasks = sc.conf.getBoolean(
"spark.cleaner.referenceTracking.blocking", false)
@volatile private var stopped = false
/** Attach a listener object to get information of when objects are cleaned. */
def attachListener(listener: CleanerListener) {
listeners += listener
}
/** Start the cleaner. */
def start() {
cleaningThread.setDaemon(true)
cleaningThread.setName("Spark Context Cleaner")
cleaningThread.start()
}
/** Stop the cleaner. */
def stop() {
stopped = true
}
/** Register a RDD for cleanup when it is garbage collected. */
def registerRDDForCleanup(rdd: RDD[_]) {
registerForCleanup(rdd, CleanRDD(rdd.id))
}
/** Register a ShuffleDependency for cleanup when it is garbage collected. */
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
}
/** Register a Broadcast for cleanup when it is garbage collected. */
def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
}
/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
}
/** Keep cleaning RDD, shuffle, and broadcast state. */
private def keepCleaning() {
while (!stopped) {
try {
val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
.map(_.asInstanceOf[CleanupTaskWeakReference])
reference.map(_.task).foreach { task =>
logDebug("Got cleaning task " + task)
referenceBuffer -= reference.get
task match {
case CleanRDD(rddId) =>
doCleanupRDD(rddId, blocking = blockOnCleanupTasks)
case CleanShuffle(shuffleId) =>
doCleanupShuffle(shuffleId, blocking = blockOnCleanupTasks)
case CleanBroadcast(broadcastId) =>
doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
}
}
} catch {
case t: Throwable => logError("Error in cleaning thread", t)
}
}
}
/** Perform RDD cleanup. */
def doCleanupRDD(rddId: Int, blocking: Boolean) {
try {
logDebug("Cleaning RDD " + rddId)
sc.unpersistRDD(rddId, blocking)
listeners.foreach(_.rddCleaned(rddId))
logInfo("Cleaned RDD " + rddId)
} catch {
case t: Throwable => logError("Error cleaning RDD " + rddId, t)
}
}
/** Perform shuffle cleanup, asynchronously. */
def doCleanupShuffle(shuffleId: Int, blocking: Boolean) {
try {
logDebug("Cleaning shuffle " + shuffleId)
mapOutputTrackerMaster.unregisterShuffle(shuffleId)
blockManagerMaster.removeShuffle(shuffleId, blocking)
listeners.foreach(_.shuffleCleaned(shuffleId))
logInfo("Cleaned shuffle " + shuffleId)
} catch {
case t: Throwable => logError("Error cleaning shuffle " + shuffleId, t)
}
}
/** Perform broadcast cleanup. */
def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
try {
logDebug("Cleaning broadcast " + broadcastId)
broadcastManager.unbroadcast(broadcastId, true, blocking)
listeners.foreach(_.broadcastCleaned(broadcastId))
logInfo("Cleaned broadcast " + broadcastId)
} catch {
case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t)
}
}
private def blockManagerMaster = sc.env.blockManager.master
private def broadcastManager = sc.env.broadcastManager
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
// Used for testing. These methods explicitly block until cleanup is completed
// to ensure more reliable testing.
}
private object ContextCleaner {
private val REF_QUEUE_POLL_TIMEOUT = 100
}
/**
* Listener class used for testing when any item has been cleaned by the Cleaner class.
*/
private[spark] trait CleanerListener {
def rddCleaned(rddId: Int)
def shuffleCleaned(shuffleId: Int)
def broadcastCleaned(broadcastId: Long)
}

@ -55,6 +55,8 @@ class ShuffleDependency[K, V](
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
val shuffleId: Int = rdd.context.newShuffleId()
rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}

@ -20,21 +20,21 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashSet
import scala.collection.mutable.{HashSet, HashMap, Map}
import scala.concurrent.Await
import akka.actor._
import akka.pattern.ask
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
import org.apache.spark.util._
private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
/** Actor class for MapOutputTrackerMaster */
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
extends Actor with Logging {
val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
@ -65,26 +65,41 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
}
}
private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
/**
* Class that keeps track of the location of the map output of
* a stage. This is abstract because different versions of MapOutputTracker
* (driver and worker) use different HashMap to store its metadata.
*/
private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
private val timeout = AkkaUtils.askTimeout(conf)
// Set to the MapOutputTrackerActor living on the driver
/** Set to the MapOutputTrackerActor living on the driver. */
var trackerActor: ActorRef = _
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
/**
* This HashMap has different behavior for the master and the workers.
*
* On the master, it serves as the source of map outputs recorded from ShuffleMapTasks.
* On the workers, it simply serves as a cache, in which a miss triggers a fetch from the
* master's corresponding HashMap.
*/
protected val mapStatuses: Map[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
/**
* Incremented every time a fetch fails so that client nodes know to clear
* their cache of map output locations if this happens.
*/
protected var epoch: Long = 0
protected val epochLock = new java.lang.Object
protected val epochLock = new AnyRef
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
/** Remembers which map output locations are currently being fetched on a worker. */
private val fetching = new HashSet[Int]
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
private def askTracker(message: Any): Any = {
/**
* Send a message to the trackerActor and get its result within a default timeout, or
* throw a SparkException if this fails.
*/
protected def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
Await.result(future, timeout)
@ -94,17 +109,17 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
}
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
private def communicate(message: Any) {
/** Send a one-way message to the trackerActor, to which we expect it to reply with true. */
protected def sendTracker(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
/**
* Called from executors to get the server URIs and output sizes of the map outputs of
* a given shuffle.
*/
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
@ -152,8 +167,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
}
else {
} else {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
@ -164,27 +178,18 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
}
}
protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
}
def stop() {
communicate(StopMapOutputTracker)
mapStatuses.clear()
metadataCleaner.cancel()
trackerActor = null
}
// Called to get current epoch number
/** Called to get current epoch number. */
def getEpoch: Long = {
epochLock.synchronized {
return epoch
}
}
// Called on workers to update the epoch number, potentially clearing old outputs
// because of a fetch failure. (Each worker task calls this with the latest epoch
// number on the master at the time it was created.)
/**
* Called from executors to update the epoch number, potentially clearing old outputs
* because of a fetch failure. Each worker task calls this with the latest epoch
* number on the master at the time it was created.
*/
def updateEpoch(newEpoch: Long) {
epochLock.synchronized {
if (newEpoch > epoch) {
@ -194,17 +199,40 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
}
}
}
/** Unregister shuffle data. */
def unregisterShuffle(shuffleId: Int) {
mapStatuses.remove(shuffleId)
}
/** Stop the tracker. */
def stop() { }
}
/**
* MapOutputTracker for the driver. This uses TimeStampedHashMap to keep track of map
* output information, which allows old output information to be cleaned up based on a TTL.
*/
private[spark] class MapOutputTrackerMaster(conf: SparkConf)
extends MapOutputTracker(conf) {
// Cache a serialized version of the output statuses for each shuffle to send them out faster
/** Cache a serialized version of the output statuses for each shuffle to send them out faster */
private var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
/**
* Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master,
* so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set).
* Other than these two scenarios, nothing should be dropped from this HashMap.
*/
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()
// For cleaning up TimeStampedHashMaps
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
@ -216,6 +244,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}
/** Register multiple map output information for the given shuffle */
def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeEpoch) {
@ -223,6 +252,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}
/** Unregister map output information of the given shuffle, mapper and block manager */
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
val arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
@ -238,6 +268,17 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}
/** Unregister shuffle data */
override def unregisterShuffle(shuffleId: Int) {
mapStatuses.remove(shuffleId)
cachedSerializedStatuses.remove(shuffleId)
}
/** Check if the given shuffle is being tracked */
def containsShuffle(shuffleId: Int): Boolean = {
cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
}
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
@ -274,23 +315,26 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
bytes
}
protected override def cleanup(cleanupTime: Long) {
super.cleanup(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
override def stop() {
super.stop()
sendTracker(StopMapOutputTracker)
mapStatuses.clear()
trackerActor = null
metadataCleaner.cancel()
cachedSerializedStatuses.clear()
}
override def updateEpoch(newEpoch: Long) {
// This might be called on the MapOutputTrackerMaster if we're running in local mode.
private def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
}
def has(shuffleId: Int): Boolean = {
cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId)
}
/**
* MapOutputTracker for the workers, which fetches map output information from the driver's
* MapOutputTrackerMaster.
*/
private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
protected val mapStatuses = new HashMap[Int, Array[MapStatus]]
}
private[spark] object MapOutputTracker {

@ -45,7 +45,7 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@ -157,7 +157,7 @@ class SparkContext(
private[spark] val addedJars = HashMap[String, Long]()
// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]]
private[spark] val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf)
@ -233,6 +233,15 @@ class SparkContext(
@volatile private[spark] var dagScheduler = new DAGScheduler(this)
dagScheduler.start()
private[spark] val cleaner: Option[ContextCleaner] = {
if (conf.getBoolean("spark.cleaner.referenceTracking", true)) {
Some(new ContextCleaner(this))
} else {
None
}
}
cleaner.foreach(_.start())
postEnvironmentUpdate()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
@ -679,7 +688,11 @@ class SparkContext(
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
def broadcast[T](value: T): Broadcast[T] = env.broadcastManager.newBroadcast[T](value, isLocal)
def broadcast[T](value: T): Broadcast[T] = {
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
}
/**
* Add a file to be downloaded with this Spark job on every node.
@ -789,8 +802,7 @@ class SparkContext(
/**
* Unpersist an RDD from memory and/or disk storage
*/
private[spark] def unpersistRDD(rdd: RDD[_], blocking: Boolean = true) {
val rddId = rdd.id
private[spark] def unpersistRDD(rddId: Int, blocking: Boolean = true) {
env.blockManager.master.removeRdd(rddId, blocking)
persistentRdds.remove(rddId)
listenerBus.post(SparkListenerUnpersistRDD(rddId))
@ -869,6 +881,7 @@ class SparkContext(
dagScheduler = null
if (dagSchedulerCopy != null) {
metadataCleaner.cancel()
cleaner.foreach(_.stop())
dagSchedulerCopy.stop()
listenerBus.stop()
taskScheduler = null

@ -180,12 +180,24 @@ object SparkEnv extends Logging {
}
}
val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster(conf)
} else {
new MapOutputTrackerWorker(conf)
}
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf)
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
serializer, conf, securityManager)
serializer, conf, securityManager, mapOutputTracker)
val connectionManager = blockManager.connectionManager
@ -193,17 +205,6 @@ object SparkEnv extends Logging {
val cacheManager = new CacheManager(blockManager)
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster(conf)
} else {
new MapOutputTracker(conf)
}
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")

@ -18,9 +18,8 @@
package org.apache.spark.broadcast
import java.io.Serializable
import java.util.concurrent.atomic.AtomicLong
import org.apache.spark._
import org.apache.spark.SparkException
/**
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
@ -29,7 +28,8 @@ import org.apache.spark._
* attempts to distribute broadcast variables using efficient broadcast algorithms to reduce
* communication cost.
*
* Broadcast variables are created from a variable `v` by calling [[SparkContext#broadcast]].
* Broadcast variables are created from a variable `v` by calling
* [[org.apache.spark.SparkContext#broadcast]].
* The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the
* `value` method. The interpreter session below shows this:
*
@ -51,49 +51,80 @@ import org.apache.spark._
* @tparam T Type of the data contained in the broadcast variable.
*/
abstract class Broadcast[T](val id: Long) extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.
/**
* Flag signifying whether the broadcast variable is valid
* (that is, not already destroyed) or not.
*/
@volatile private var _isValid = true
override def toString = "Broadcast(" + id + ")"
}
/** Get the broadcasted value. */
def value: T = {
assertValid()
getValue()
}
private[spark]
class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
extends Logging with Serializable {
/**
* Asynchronously delete cached copies of this broadcast on the executors.
* If the broadcast is used after this is called, it will need to be re-sent to each executor.
*/
def unpersist() {
unpersist(blocking = false)
}
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
/**
* Delete cached copies of this broadcast on the executors. If the broadcast is used after
* this is called, it will need to be re-sent to each executor.
* @param blocking Whether to block until unpersisting has completed
*/
def unpersist(blocking: Boolean) {
assertValid()
doUnpersist(blocking)
}
initialize()
/**
* Destroy all data and metadata related to this broadcast variable. Use this with caution;
* once a broadcast variable has been destroyed, it cannot be used again.
*/
private[spark] def destroy(blocking: Boolean) {
assertValid()
_isValid = false
doDestroy(blocking)
}
// Called by SparkContext or Executor before using Broadcast
private def initialize() {
synchronized {
if (!initialized) {
val broadcastFactoryClass = conf.get(
"spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
/**
* Whether this Broadcast is actually usable. This should be false once persisted state is
* removed from the driver.
*/
private[spark] def isValid: Boolean = {
_isValid
}
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
/**
* Actually get the broadcasted value. Concrete implementations of Broadcast class must
* define their own way to get the value.
*/
private[spark] def getValue(): T
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver, conf, securityManager)
/**
* Actually unpersist the broadcasted value on the executors. Concrete implementations of
* Broadcast class must define their own logic to unpersist their own data.
*/
private[spark] def doUnpersist(blocking: Boolean)
initialized = true
}
/**
* Actually destroy all data and metadata related to this broadcast variable.
* Implementation of Broadcast class must define their own logic to destroy their own
* state.
*/
private[spark] def doDestroy(blocking: Boolean)
/** Check if this broadcast is valid. If not valid, exception is thrown. */
private[spark] def assertValid() {
if (!_isValid) {
throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString))
}
}
def stop() {
broadcastFactory.stop()
}
private val nextBroadcastId = new AtomicLong(0)
def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def isDriver = _isDriver
override def toString = "Broadcast(" + id + ")"
}

@ -27,7 +27,8 @@ import org.apache.spark.SparkConf
* entire Spark job.
*/
trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
def stop(): Unit
}

@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.util.concurrent.atomic.AtomicLong
import org.apache.spark._
private[spark] class BroadcastManager(
val isDriver: Boolean,
conf: SparkConf,
securityManager: SecurityManager)
extends Logging {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
initialize()
// Called by SparkContext or Executor before using Broadcast
private def initialize() {
synchronized {
if (!initialized) {
val broadcastFactoryClass =
conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver, conf, securityManager)
initialized = true
}
}
}
def stop() {
broadcastFactory.stop()
}
private val nextBroadcastId = new AtomicLong(0)
def newBroadcast[T](value_ : T, isLocal: Boolean) = {
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
}
}

@ -17,34 +17,65 @@
package org.apache.spark.broadcast
import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
import java.net.{URL, URLConnection, URI}
import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream}
import java.net.{URI, URL, URLConnection}
import java.util.concurrent.TimeUnit
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream}
import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv}
import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
/**
* A [[org.apache.spark.broadcast.Broadcast]] implementation that uses HTTP server
* as a broadcast mechanism. The first time a HTTP broadcast variable (sent as part of a
* task) is deserialized in the executor, the broadcasted data is fetched from the driver
* (through a HTTP server running at the driver) and stored in the BlockManager of the
* executor to speed up future accesses.
*/
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def getValue = value_
def blockId = BroadcastBlockId(id)
val blockId = BroadcastBlockId(id)
/*
* Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster
* does not need to be told about this block, as no other node needs to know about this data block.
*/
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
SparkEnv.get.blockManager.putSingle(
blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
}
if (!isLocal) {
HttpBroadcast.write(id, value_)
}
// Called by JVM when deserializing an object
/**
* Remove all persisted state associated with this HTTP broadcast on the executors.
*/
def doUnpersist(blocking: Boolean) {
HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
}
/**
* Remove all persisted state associated with this HTTP broadcast on the executors and driver.
*/
def doDestroy(blocking: Boolean) {
HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
}
/** Used by the JVM when serializing this object. */
private def writeObject(out: ObjectOutputStream) {
assertValid()
out.defaultWriteObject()
}
/** Used by the JVM when deserializing this object. */
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
HttpBroadcast.synchronized {
@ -54,7 +85,13 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
/*
* We cache broadcast data in the BlockManager so that subsequent tasks using it
* do not need to re-fetch. This data is only used locally and no other node
* needs to fetch this block, so we don't notify the master.
*/
SparkEnv.get.blockManager.putSingle(
blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
@ -63,23 +100,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
}
}
/**
* A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
*/
class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def stop() { HttpBroadcast.stop() }
}
private object HttpBroadcast extends Logging {
private[spark] object HttpBroadcast extends Logging {
private var initialized = false
private var broadcastDir: File = null
private var compress: Boolean = false
private var bufferSize: Int = 65536
@ -89,11 +111,9 @@ private object HttpBroadcast extends Logging {
// TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
private val files = new TimeStampedHashSet[String]
private var cleaner: MetadataCleaner = null
private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt
private var compressionCodec: CompressionCodec = null
private var cleaner: MetadataCleaner = null
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
synchronized {
@ -136,8 +156,10 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri)
}
def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
def write(id: Long, value: Any) {
val file = new File(broadcastDir, BroadcastBlockId(id).name)
val file = getFile(id)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
@ -160,7 +182,7 @@ private object HttpBroadcast extends Logging {
if (securityManager.isAuthenticationEnabled()) {
logDebug("broadcast security enabled")
val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
uc = newuri.toURL().openConnection()
uc = newuri.toURL.openConnection()
uc.setAllowUserInteraction(false)
} else {
logDebug("broadcast not using security")
@ -169,7 +191,7 @@ private object HttpBroadcast extends Logging {
val in = {
uc.setReadTimeout(httpReadTimeout)
val inputStream = uc.getInputStream();
val inputStream = uc.getInputStream
if (compress) {
compressionCodec.compressedInputStream(inputStream)
} else {
@ -183,20 +205,48 @@ private object HttpBroadcast extends Logging {
obj
}
def cleanup(cleanupTime: Long) {
/**
* Remove all persisted blocks associated with this HTTP broadcast on the executors.
* If removeFromDriver is true, also remove these persisted blocks on the driver
* and delete the associated broadcast file.
*/
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
if (removeFromDriver) {
val file = getFile(id)
files.remove(file.toString)
deleteBroadcastFile(file)
}
}
/**
* Periodically clean up old broadcasts by removing the associated map entries and
* deleting the associated files.
*/
private def cleanup(cleanupTime: Long) {
val iterator = files.internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val (file, time) = (entry.getKey, entry.getValue)
if (time < cleanupTime) {
try {
iterator.remove()
new File(file.toString).delete()
logInfo("Deleted broadcast file '" + file + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
}
iterator.remove()
deleteBroadcastFile(new File(file.toString))
}
}
}
private def deleteBroadcastFile(file: File) {
try {
if (file.exists) {
if (file.delete()) {
logInfo("Deleted broadcast file: %s".format(file))
} else {
logWarning("Could not delete broadcast file: %s".format(file))
}
}
} catch {
case e: Exception =>
logError("Exception while deleting broadcast file: %s".format(file), e)
}
}
}

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import org.apache.spark.{SecurityManager, SparkConf}
/**
* A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a
* HTTP server as the broadcast mechanism. Refer to
* [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism.
*/
class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def stop() { HttpBroadcast.stop() }
/**
* Remove all persisted state associated with the HTTP broadcast with the given ID.
* @param removeFromDriver Whether to remove state from the driver
* @param blocking Whether to block until unbroadcasted
*/
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
HttpBroadcast.unpersist(id, removeFromDriver, blocking)
}
}

@ -17,24 +17,43 @@
package org.apache.spark.broadcast
import java.io._
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
import scala.math
import scala.util.Random
import org.apache.spark._
import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.Utils
/**
* A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
* protocol to do a distributed transfer of the broadcasted data to the executors.
* The mechanism is as follows. The driver serializes the broadcasted data,
* divides it into smaller chunks, and stores them in the BlockManager of the driver.
* These chunks are reported to the BlockManagerMaster so that all the executors can
* learn the location of those chunks. The first time the broadcast variable (sent as
* part of a task) is deserialized at an executor, all the chunks are fetched using
* the BlockManager. When all the chunks are fetched (initially from the driver's
* BlockManager), they are combined and deserialized to recreate the broadcasted data.
* However, the chunks are also stored in the BlockManager and reported to the
* BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns
* multiple locations for each chunk. Hence, subsequent fetches of each chunk will be
* made to other executors who already have those chunks, resulting in a distributed
* fetching. This prevents the driver from being the bottleneck in sending out multiple
* copies of the broadcast data (one per executor) as done by the
* [[org.apache.spark.broadcast.HttpBroadcast]].
*/
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def getValue = value_
def broadcastId = BroadcastBlockId(id)
val broadcastId = BroadcastBlockId(id)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
SparkEnv.get.blockManager.putSingle(
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
}
@transient var arrayOfBlocks: Array[TorrentBlock] = null
@ -46,32 +65,52 @@ extends Broadcast[T](id) with Logging with Serializable {
sendBroadcast()
}
def sendBroadcast() {
var tInfo = TorrentBroadcast.blockifyObject(value_)
/**
* Remove all persisted state associated with this Torrent broadcast on the executors.
*/
def doUnpersist(blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
}
/**
* Remove all persisted state associated with this Torrent broadcast on the executors
* and driver.
*/
def doDestroy(blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
}
def sendBroadcast() {
val tInfo = TorrentBroadcast.blockifyObject(value_)
totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
hasBlocks = tInfo.totalBlocks
// Store meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
val metaId = BroadcastBlockId(id, "meta")
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
}
// Store individual pieces
for (i <- 0 until totalBlocks) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
val pieceId = BroadcastBlockId(id, "piece" + i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
}
}
}
// Called by JVM when deserializing an object
/** Used by the JVM when serializing this object. */
private def writeObject(out: ObjectOutputStream) {
assertValid()
out.defaultWriteObject()
}
/** Used by the JVM when deserializing this object. */
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TorrentBroadcast.synchronized {
@ -86,18 +125,22 @@ extends Broadcast[T](id) with Logging with Serializable {
// Initialize @transient variables that will receive garbage values from the master.
resetWorkerVariables()
if (receiveBroadcast(id)) {
if (receiveBroadcast()) {
value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
// Store the merged copy in cache so that the next worker doesn't need to rebuild it.
// This creates a tradeoff between memory usage and latency.
// Storing copy doubles the memory footprint; not storing doubles deserialization cost.
/* Store the merged copy in cache so that the next worker doesn't need to rebuild it.
* This creates a trade-off between memory usage and latency. Storing copy doubles
* the memory footprint; not storing doubles deserialization cost. Also,
* this does not need to be reported to BlockManagerMaster since other executors
* do not need to access this block (they only need to fetch the chunks,
* which are reported).
*/
SparkEnv.get.blockManager.putSingle(
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
// Remove arrayOfBlocks from memory once value_ is on local cache
resetWorkerVariables()
} else {
} else {
logError("Reading broadcast variable " + id + " failed")
}
@ -114,9 +157,10 @@ extends Broadcast[T](id) with Logging with Serializable {
hasBlocks = 0
}
def receiveBroadcast(variableID: Long): Boolean = {
// Receive meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
def receiveBroadcast(): Boolean = {
// Receive meta-info about the size of broadcast data,
// the number of chunks it is divided into, etc.
val metaId = BroadcastBlockId(id, "meta")
var attemptId = 10
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
@ -138,17 +182,21 @@ extends Broadcast[T](id) with Logging with Serializable {
return false
}
// Receive actual blocks
/*
* Fetch actual chunks of data. Note that all these chunks are stored in
* the BlockManager and reported to the master, so that other executors
* can find out and pull the chunks from this executor.
*/
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
for (pid <- recvOrder) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
val pieceId = BroadcastBlockId(id, "piece" + pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
hasBlocks += 1
SparkEnv.get.blockManager.putSingle(
pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
@ -156,16 +204,16 @@ extends Broadcast[T](id) with Logging with Serializable {
}
}
(hasBlocks == totalBlocks)
hasBlocks == totalBlocks
}
}
private object TorrentBroadcast
extends Logging {
private[spark] object TorrentBroadcast extends Logging {
private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
private var initialized = false
private var conf: SparkConf = null
def initialize(_isDriver: Boolean, conf: SparkConf) {
TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
synchronized {
@ -179,39 +227,37 @@ extends Logging {
initialized = false
}
lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
def blockifyObject[T](obj: T): TorrentInfo = {
val byteArray = Utils.serialize[T](obj)
val bais = new ByteArrayInputStream(byteArray)
var blockNum = (byteArray.length / BLOCK_SIZE)
var blockNum = byteArray.length / BLOCK_SIZE
if (byteArray.length % BLOCK_SIZE != 0) {
blockNum += 1
}
var retVal = new Array[TorrentBlock](blockNum)
var blockID = 0
val blocks = new Array[TorrentBlock](blockNum)
var blockId = 0
for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
var tempByteArray = new Array[Byte](thisBlockSize)
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
val tempByteArray = new Array[Byte](thisBlockSize)
bais.read(tempByteArray, 0, thisBlockSize)
retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
blockID += 1
blocks(blockId) = new TorrentBlock(blockId, tempByteArray)
blockId += 1
}
bais.close()
val tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
tInfo.hasBlocks = blockNum
tInfo
val info = TorrentInfo(blocks, blockNum, byteArray.length)
info.hasBlocks = blockNum
info
}
def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
totalBytes: Int,
totalBlocks: Int): T = {
def unBlockifyObject[T](
arrayOfBlocks: Array[TorrentBlock],
totalBytes: Int,
totalBlocks: Int): T = {
val retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
@ -220,6 +266,13 @@ extends Logging {
Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
}
/**
* Remove all persisted blocks associated with this torrent broadcast on the executors.
* If removeFromDriver is true, also remove these persisted blocks on the driver.
*/
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
}
}
private[spark] case class TorrentBlock(
@ -228,25 +281,10 @@ private[spark] case class TorrentBlock(
extends Serializable
private[spark] case class TorrentInfo(
@transient arrayOfBlocks : Array[TorrentBlock],
@transient arrayOfBlocks: Array[TorrentBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@transient var hasBlocks = 0
}
/**
* A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast.
*/
class TorrentBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
TorrentBroadcast.initialize(isDriver, conf)
}
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
def stop() { TorrentBroadcast.stop() }
}

@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import org.apache.spark.{SecurityManager, SparkConf}
/**
* A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
* protocol to do a distributed transfer of the broadcasted data to the executors. Refer to
* [[org.apache.spark.broadcast.TorrentBroadcast]] for more details.
*/
class TorrentBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
TorrentBroadcast.initialize(isDriver, conf)
}
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
def stop() { TorrentBroadcast.stop() }
/**
* Remove all persisted state associated with the torrent broadcast with the given ID.
* @param removeFromDriver Whether to remove state from the driver.
* @param blocking Whether to block until unbroadcasted
*/
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
}
}

@ -17,7 +17,6 @@
package org.apache.spark.network
import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._

@ -138,6 +138,8 @@ abstract class RDD[T: ClassTag](
"Cannot change storage level of an RDD after it was already assigned a level")
}
sc.persistRDD(this)
// Register the RDD with the ContextCleaner for automatic GC-based cleanup
sc.cleaner.foreach(_.registerRDDForCleanup(this))
storageLevel = newLevel
this
}
@ -156,7 +158,7 @@ abstract class RDD[T: ClassTag](
*/
def unpersist(blocking: Boolean = true): RDD[T] = {
logInfo("Removing RDD " + id + " from persistence list")
sc.unpersistRDD(this, blocking)
sc.unpersistRDD(id, blocking)
storageLevel = StorageLevel.NONE
this
}
@ -1141,5 +1143,4 @@ abstract class RDD[T: ClassTag](
def toJavaRDD() : JavaRDD[T] = {
new JavaRDD(this)(elementClassTag)
}
}

@ -32,7 +32,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
import org.apache.spark.util.Utils
/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@ -80,13 +80,13 @@ class DAGScheduler(
private[scheduler] def numTotalJobs: Int = nextJobId.get()
private val nextStageId = new AtomicInteger(0)
private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]]
private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]]
private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage]
private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]]
private[scheduler] val stageIdToStage = new HashMap[Int, Stage]
private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage]
private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob]
private[scheduler] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
private[scheduler] val stageToInfos = new HashMap[Stage, StageInfo]
// Stages we need to run whose parents aren't done
private[scheduler] val waitingStages = new HashSet[Stage]
@ -98,7 +98,7 @@ class DAGScheduler(
private[scheduler] val failedStages = new HashSet[Stage]
// Missing tasks from each stage
private[scheduler] val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]]
private[scheduler] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
private[scheduler] val activeJobs = new HashSet[ActiveJob]
@ -113,9 +113,6 @@ class DAGScheduler(
// stray messages to detect.
private val failedEpoch = new HashMap[String, Long]
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup, env.conf)
taskScheduler.setDAGScheduler(this)
/**
@ -258,7 +255,7 @@ class DAGScheduler(
: Stage =
{
val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite)
if (mapOutputTracker.has(shuffleDep.shuffleId)) {
if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
for (i <- 0 until locs.size) {
@ -390,6 +387,9 @@ class DAGScheduler(
stageIdToStage -= stageId
stageIdToJobIds -= stageId
ShuffleMapTask.removeStage(stageId)
ResultTask.removeStage(stageId)
logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
@ -1084,26 +1084,10 @@ class DAGScheduler(
Nil
}
private def cleanup(cleanupTime: Long) {
Map(
"stageIdToStage" -> stageIdToStage,
"shuffleToMapStage" -> shuffleToMapStage,
"pendingTasks" -> pendingTasks,
"stageToInfos" -> stageToInfos,
"jobIdToStageIds" -> jobIdToStageIds,
"stageIdToJobIds" -> stageIdToJobIds).
foreach { case (s, t) =>
val sizeBefore = t.size
t.clearOldValues(cleanupTime)
logInfo("%s %d --> %d".format(s, sizeBefore, t.size))
}
}
def stop() {
if (eventProcessActor != null) {
eventProcessActor ! StopDAGScheduler
}
metadataCleaner.cancel()
taskScheduler.stop()
}
}


@ -20,21 +20,17 @@ package org.apache.spark.scheduler
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
private[spark] object ResultTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
// TODO: This object shouldn't have global variables
val metadataCleaner = new MetadataCleaner(
MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf)
private val serializedInfoCache = new HashMap[Int, Array[Byte]]
def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
{
@ -67,6 +63,10 @@ private[spark] object ResultTask {
(rdd, func)
}
def removeStage(stageId: Int) {
serializedInfoCache.remove(stageId)
}
def clearCache() {
synchronized {
serializedInfoCache.clear()


@ -24,22 +24,16 @@ import scala.collection.mutable.HashMap
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
private[spark] object ShuffleMapTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
// TODO: This object shouldn't have global variables
val metadataCleaner = new MetadataCleaner(
MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues, new SparkConf)
private val serializedInfoCache = new HashMap[Int, Array[Byte]]
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
@ -80,6 +74,10 @@ private[spark] object ShuffleMapTask {
HashMap(set.toSeq: _*)
}
def removeStage(stageId: Int) {
serializedInfoCache.remove(stageId)
}
def clearCache() {
synchronized {
serializedInfoCache.clear()
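ResultTask and ShuffleMapTask now share the same cache lifecycle: entries are keyed by stage ID and dropped explicitly when the DAGScheduler removes the stage, instead of waiting for a TTL-based MetadataCleaner. A hedged, standalone sketch of that pattern (the object name and `serialize` parameter are illustrative, not Spark code):

import scala.collection.mutable.HashMap

object SerializedTaskInfoCache {
  private val cache = new HashMap[Int, Array[Byte]]

  // Serialize at most once per stage and reuse the cached bytes afterwards
  def getOrCompute(stageId: Int)(serialize: => Array[Byte]): Array[Byte] = synchronized {
    cache.getOrElseUpdate(stageId, serialize)
  }

  // Called when the stage is removed (cf. DAGScheduler.removeStage above)
  def removeStage(stageId: Int): Unit = synchronized {
    cache.remove(stageId)
  }
}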


@ -42,7 +42,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
*
* THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
* SchedulerBackends sycnchronize on themselves when they want to send events here, and then
* SchedulerBackends synchronize on themselves when they want to send events here, and then
* acquire a lock on us, so we need to make sure that we don't try to lock the backend while
* we are holding a lock on ourselves.
*/


@ -34,7 +34,7 @@ private[spark] sealed abstract class BlockId {
def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
def isRDD = isInstanceOf[RDDBlockId]
def isShuffle = isInstanceOf[ShuffleBlockId]
def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
def isBroadcast = isInstanceOf[BroadcastBlockId]
override def toString = name
override def hashCode = name.hashCode
@ -48,18 +48,13 @@ private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockI
def name = "rdd_" + rddId + "_" + splitIndex
}
private[spark]
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def name = "broadcast_" + broadcastId
}
private[spark]
case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
def name = broadcastId.name + "_" + hType
private[spark] case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
}
private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
@ -83,8 +78,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
private[spark] object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)".r
val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
val TEST = "test_(.*)".r
@ -95,10 +89,8 @@ private[spark] object BlockId {
RDDBlockId(rddId.toInt, splitIndex.toInt)
case SHUFFLE(shuffleId, mapId, reduceId) =>
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case BROADCAST(broadcastId) =>
BroadcastBlockId(broadcastId.toLong)
case BROADCAST_HELPER(broadcastId, hType) =>
BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
case BROADCAST(broadcastId, field) =>
BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
case TASKRESULT(taskId) =>
TaskResultBlockId(taskId.toLong)
case STREAM(streamId, uniqueId) =>
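With the helper block ID type folded into `BroadcastBlockId`, the naming scheme works out as follows. A REPL-style illustration with made-up values (the BlockId types are Spark-internal, so this would run inside Spark's own code):

import org.apache.spark.storage.{BlockId, BroadcastBlockId}

val a = BroadcastBlockId(5).name            // "broadcast_5"
val b = BroadcastBlockId(5, "meta").name    // "broadcast_5_meta"
val c = BlockId("broadcast_5_piece0")       // parses back to BroadcastBlockId(5, "piece0")
val d = BlockId("broadcast_5").isBroadcast  // true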


@ -19,20 +19,22 @@ package org.apache.spark.storage
import java.io.{File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import sun.nio.ch.DirectBuffer
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}
import org.apache.spark.{Logging, MapOutputTracker, SecurityManager, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
import org.apache.spark.util._
sealed trait Values
case class ByteBufferValues(buffer: ByteBuffer) extends Values
@ -46,7 +48,8 @@ private[spark] class BlockManager(
val defaultSerializer: Serializer,
maxMemory: Long,
val conf: SparkConf,
securityManager: SecurityManager)
securityManager: SecurityManager,
mapOutputTracker: MapOutputTracker)
extends Logging {
val shuffleBlockManager = new ShuffleBlockManager(this)
@ -55,7 +58,7 @@ private[spark] class BlockManager(
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
private[storage] val memoryStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore = new DiskStore(this, diskBlockManager)
var tachyonInitialized = false
private[storage] lazy val tachyonStore: TachyonStore = {
@ -98,7 +101,7 @@ private[spark] class BlockManager(
val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf)
val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this, mapOutputTracker)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
// Pending re-registration action being executed asynchronously or null if none
@ -137,9 +140,10 @@ private[spark] class BlockManager(
master: BlockManagerMaster,
serializer: Serializer,
conf: SparkConf,
securityManager: SecurityManager) = {
securityManager: SecurityManager,
mapOutputTracker: MapOutputTracker) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
conf, securityManager)
conf, securityManager, mapOutputTracker)
}
/**
@ -217,9 +221,26 @@ private[spark] class BlockManager(
}
/**
* Get storage level of local block. If no info exists for the block, then returns null.
* Get the BlockStatus for the block identified by the given ID, if it exists.
* NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon.
*/
def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
def getStatus(blockId: BlockId): Option[BlockStatus] = {
blockInfo.get(blockId).map { info =>
val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L
// Assume that block is not in Tachyon
BlockStatus(info.level, memSize, diskSize, 0L)
}
}
/**
* Get the ids of existing blocks that match the given filter. Note that this will
* query the blocks stored in the disk block manager (that the block manager
* may not know of).
*/
def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = {
(blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq
}
/**
* Tell the master about the current storage status of a block. This will send a block update
@ -525,9 +546,8 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
* The Block will be appended to the File specified by filename.
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
* The Block will be appended to the File specified by filename. This is currently used for
* writing shuffle files out. Callers should handle error cases.
*/
def getDiskWriter(
blockId: BlockId,
@ -863,11 +883,22 @@ private[spark] class BlockManager(
* @return The number of blocks removed.
*/
def removeRdd(rddId: Int): Int = {
// TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
// from RDD.id to blocks.
// TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
logInfo("Removing RDD " + rddId)
val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false))
blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
blocksToRemove.size
}
/**
* Remove all blocks belonging to the given broadcast.
*/
def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
logInfo("Removing broadcast " + broadcastId)
val blocksToRemove = blockInfo.keys.collect {
case bid @ BroadcastBlockId(`broadcastId`, _) => bid
}
blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
blocksToRemove.size
}
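The `collect` above uses Scala's backquoted identifier pattern so that only blocks of the requested broadcast match, whatever their field suffix. A small REPL-style illustration with made-up IDs:

val broadcastId = 7L
val blocks = Seq(BroadcastBlockId(7L), BroadcastBlockId(7L, "piece0"), BroadcastBlockId(8L))
// `broadcastId` in backquotes matches against the existing value instead of binding a new variable
val matching = blocks.collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid }
// matching contains only the two blocks of broadcast 7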
@ -908,10 +939,10 @@ private[spark] class BlockManager(
}
private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
val iterator = blockInfo.internalMap.entrySet().iterator()
val iterator = blockInfo.getEntrySet.iterator
while (iterator.hasNext) {
val entry = iterator.next()
val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp)
if (time < cleanupTime && shouldDrop(id)) {
info.synchronized {
val level = info.level
@ -935,7 +966,7 @@ private[spark] class BlockManager(
def shouldCompress(blockId: BlockId): Boolean = blockId match {
case ShuffleBlockId(_, _, _) => compressShuffle
case BroadcastBlockId(_) => compressBroadcast
case BroadcastBlockId(_, _) => compressBroadcast
case RDDBlockId(_, _) => compressRdds
case TempBlockId(_) => compressShuffleSpill
case _ => false


@ -81,6 +81,14 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
/**
* Check if the block manager master has a block. Note that this can only check blocks that
* have been reported to the block manager master.
*/
def contains(blockId: BlockId) = {
!getLocations(blockId).isEmpty
}
/** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
@ -99,12 +107,10 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
askDriverWithReply(RemoveBlock(blockId))
}
/**
* Remove all blocks belonging to the given RDD.
*/
/** Remove all blocks belonging to the given RDD. */
def removeRdd(rddId: Int, blocking: Boolean) {
val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
future onFailure {
future.onFailure {
case e: Throwable => logError("Failed to remove RDD " + rddId, e)
}
if (blocking) {
@ -112,6 +118,31 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
}
}
/** Remove all blocks belonging to the given shuffle. */
def removeShuffle(shuffleId: Int, blocking: Boolean) {
val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
future.onFailure {
case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e)
}
if (blocking) {
Await.result(future, timeout)
}
}
/** Remove all blocks belonging to the given broadcast. */
def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
val future = askDriverWithReply[Future[Seq[Int]]](
RemoveBroadcast(broadcastId, removeFromMaster))
future.onFailure {
case e: Throwable =>
logError("Failed to remove broadcast " + broadcastId +
" with removeFromMaster = " + removeFromMaster, e)
}
if (blocking) {
Await.result(future, timeout)
}
}
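A hedged sketch of how the driver side is expected to drive these new calls; the method names and parameters come from this diff, but the surrounding function is illustrative and would have to live inside Spark's own packages since `BlockManagerMaster` is Spark-internal:

def cleanupExample(master: BlockManagerMaster) {
  // Fire-and-forget: failures are only logged, the calling thread is not blocked
  master.removeShuffle(shuffleId = 2, blocking = false)
  // Synchronous: wait until every block manager has dropped the broadcast's blocks
  master.removeBroadcast(broadcastId = 9L, removeFromMaster = true, blocking = true)
}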
/**
* Return the memory status for each block manager, in the form of a map from
* the block manager's id to two long values. The first value is the maximum
@ -126,6 +157,51 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
}
/**
* Return the block's status on all block managers, if any. NOTE: This is a
* potentially expensive operation and should only be used for testing.
*
* If askSlaves is true, this invokes the master to query each block manager for the most
* updated block statuses. This is useful when the master is not informed of the given block
* by all block managers.
*/
def getBlockStatus(
blockId: BlockId,
askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = {
val msg = GetBlockStatus(blockId, askSlaves)
/*
* To avoid potential deadlocks, the use of Futures is necessary, because the master actor
* should not block on waiting for a block manager, which can in turn be waiting for the
* master actor for a response to a prior message.
*/
val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
val (blockManagerIds, futures) = response.unzip
val result = Await.result(Future.sequence(futures), timeout)
if (result == null) {
throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId)
}
val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]]
blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) =>
status.map { s => (blockManagerId, s) }
}.toMap
}
/**
* Return a list of ids of existing blocks such that the ids match the given filter. NOTE: This
* is a potentially expensive operation and should only be used for testing.
*
* If askSlaves is true, this invokes the master to query each block manager for the most
* updated block statuses. This is useful when the master is not informed of the given block
* by all block managers.
*/
def getMatchingBlockIds(
filter: BlockId => Boolean,
askSlaves: Boolean): Seq[BlockId] = {
val msg = GetMatchingBlockIds(filter, askSlaves)
val future = askDriverWithReply[Future[Seq[BlockId]]](msg)
Await.result(future, timeout)
}
/** Stop the driver actor, called only on the Spark driver node */
def stop() {
if (driverActor != null) {


@ -94,9 +94,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetStorageStatus =>
sender ! storageStatus
case GetBlockStatus(blockId, askSlaves) =>
sender ! blockStatus(blockId, askSlaves)
case GetMatchingBlockIds(filter, askSlaves) =>
sender ! getMatchingBlockIds(filter, askSlaves)
case RemoveRdd(rddId) =>
sender ! removeRdd(rddId)
case RemoveShuffle(shuffleId) =>
sender ! removeShuffle(shuffleId)
case RemoveBroadcast(broadcastId, removeFromDriver) =>
sender ! removeBroadcast(broadcastId, removeFromDriver)
case RemoveBlock(blockId) =>
removeBlockFromWorkers(blockId)
sender ! true
@ -140,9 +152,41 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
// The dispatcher is used as an implicit argument into the Future sequence construction.
import context.dispatcher
val removeMsg = RemoveRdd(rddId)
Future.sequence(blockManagerInfo.values.map { bm =>
bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
}.toSeq)
Future.sequence(
blockManagerInfo.values.map { bm =>
bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
}.toSeq
)
}
private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
// Nothing to do in the BlockManagerMasterActor data structures
import context.dispatcher
val removeMsg = RemoveShuffle(shuffleId)
Future.sequence(
blockManagerInfo.values.map { bm =>
bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean]
}.toSeq
)
}
/**
* Delegate RemoveBroadcast messages to each BlockManager because the master may not be notified
* of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed
* from the executors, but not from the driver.
*/
private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
// TODO: Consolidate usages of <driver>
import context.dispatcher
val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
val requiredBlockManagers = blockManagerInfo.values.filter { info =>
removeFromDriver || info.blockManagerId.executorId != "<driver>"
}
Future.sequence(
requiredBlockManagers.map { bm =>
bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
}.toSeq
)
}
private def removeBlockManager(blockManagerId: BlockManagerId) {
@ -225,6 +269,61 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
}.toArray
}
/**
* Return the block's status for all block managers, if any. NOTE: This is a
* potentially expensive operation and should only be used for testing.
*
* If askSlaves is true, the master queries each block manager for the most updated block
* statuses. This is useful when the master is not informed of the given block by all block
* managers.
*/
private def blockStatus(
blockId: BlockId,
askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
import context.dispatcher
val getBlockStatus = GetBlockStatus(blockId)
/*
* Rather than blocking on the block status query, master actor should simply return
* Futures to avoid potential deadlocks. This can arise if there exists a block manager
* that is also waiting for this master actor's response to a previous message.
*/
blockManagerInfo.values.map { info =>
val blockStatusFuture =
if (askSlaves) {
info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]]
} else {
Future { info.getStatus(blockId) }
}
(info.blockManagerId, blockStatusFuture)
}.toMap
}
/**
* Return the ids of blocks present in all the block managers that match the given filter.
* NOTE: This is a potentially expensive operation and should only be used for testing.
*
* If askSlaves is true, the master queries each block manager for the most updated block
* statuses. This is useful when the master is not informed of the given block by all block
* managers.
*/
private def getMatchingBlockIds(
filter: BlockId => Boolean,
askSlaves: Boolean): Future[Seq[BlockId]] = {
import context.dispatcher
val getMatchingBlockIds = GetMatchingBlockIds(filter)
Future.sequence(
blockManagerInfo.values.map { info =>
val future =
if (askSlaves) {
info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]]
} else {
Future { info.blocks.keys.filter(filter).toSeq }
}
future
}
).map(_.flatten.toSeq)
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
@ -334,6 +433,8 @@ private[spark] class BlockManagerInfo(
logInfo("Registering block manager %s with %s RAM".format(
blockManagerId.hostPort, Utils.bytesToString(maxMem)))
def getStatus(blockId: BlockId) = Option(_blocks.get(blockId))
def updateLastSeenMs() {
_lastSeenMs = System.currentTimeMillis()
}
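Throughout the new handlers, the master actor never blocks on slave replies; it asks each registered block manager and folds the responses into a single Future. A minimal, generic sketch of that ask-and-sequence idiom (the method, actor refs and timeout are illustrative, not Spark API):

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
import scala.reflect.ClassTag
import akka.actor.ActorRef
import akka.pattern.ask
import akka.util.Timeout

def askAll[T: ClassTag](slaves: Seq[ActorRef], msg: Any)
    (implicit ec: ExecutionContext): Future[Seq[T]] = {
  implicit val timeout = Timeout(30.seconds)
  // Ask every slave and aggregate all replies without blocking the caller
  Future.sequence(slaves.map { slave => (slave ? msg).mapTo[T] })
}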


@ -34,6 +34,13 @@ private[storage] object BlockManagerMessages {
// Remove all blocks belonging to a specific RDD.
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
// Remove all blocks belonging to a specific shuffle.
case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave
// Remove all blocks belonging to a specific broadcast.
case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true)
extends ToBlockManagerSlave
//////////////////////////////////////////////////////////////////////////////////
// Messages from slaves to the master.
@ -80,7 +87,8 @@ private[storage] object BlockManagerMessages {
}
object UpdateBlockInfo {
def apply(blockManagerId: BlockManagerId,
def apply(
blockManagerId: BlockManagerId,
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
@ -108,7 +116,13 @@ private[storage] object BlockManagerMessages {
case object GetMemoryStatus extends ToBlockManagerMaster
case object ExpireDeadHosts extends ToBlockManagerMaster
case object GetStorageStatus extends ToBlockManagerMaster
case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true)
extends ToBlockManagerMaster
case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true)
extends ToBlockManagerMaster
case object ExpireDeadHosts extends ToBlockManagerMaster
}


@ -17,8 +17,11 @@
package org.apache.spark.storage
import akka.actor.Actor
import scala.concurrent.Future
import akka.actor.{ActorRef, Actor}
import org.apache.spark.{Logging, MapOutputTracker}
import org.apache.spark.storage.BlockManagerMessages._
/**
@ -26,14 +29,59 @@ import org.apache.spark.storage.BlockManagerMessages._
* this is used to remove blocks from the slave's BlockManager.
*/
private[storage]
class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
override def receive = {
class BlockManagerSlaveActor(
blockManager: BlockManager,
mapOutputTracker: MapOutputTracker)
extends Actor with Logging {
import context.dispatcher
// Operations that involve removing blocks may be slow and should be done asynchronously
override def receive = {
case RemoveBlock(blockId) =>
blockManager.removeBlock(blockId)
doAsync[Boolean]("removing block " + blockId, sender) {
blockManager.removeBlock(blockId)
true
}
case RemoveRdd(rddId) =>
val numBlocksRemoved = blockManager.removeRdd(rddId)
sender ! numBlocksRemoved
doAsync[Int]("removing RDD " + rddId, sender) {
blockManager.removeRdd(rddId)
}
case RemoveShuffle(shuffleId) =>
doAsync[Boolean]("removing shuffle " + shuffleId, sender) {
if (mapOutputTracker != null) {
mapOutputTracker.unregisterShuffle(shuffleId)
}
blockManager.shuffleBlockManager.removeShuffle(shuffleId)
}
case RemoveBroadcast(broadcastId, tellMaster) =>
doAsync[Int]("removing broadcast " + broadcastId, sender) {
blockManager.removeBroadcast(broadcastId, tellMaster)
}
case GetBlockStatus(blockId, _) =>
sender ! blockManager.getStatus(blockId)
case GetMatchingBlockIds(filter, _) =>
sender ! blockManager.getMatchingBlockIds(filter)
}
private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) {
val future = Future {
logDebug(actionMessage)
body
}
future.onSuccess { case response =>
logDebug("Done " + actionMessage + ", response is " + response)
responseActor ! response
logDebug("Sent response: " + response + " to " + responseActor)
}
future.onFailure { case t: Throwable =>
logError("Error in " + actionMessage, t)
responseActor ! null.asInstanceOf[T]
}
}
}
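One subtlety `doAsync` handles: the sender is captured while the message is being processed and passed in explicitly, because by the time a Future callback runs, the actor's `sender` may already point at a different requester. A hedged sketch of the same idiom in isolation (the actor and message type are made up):

import akka.actor.{Actor, ActorRef}
import scala.concurrent.Future

case class SlowRequest(input: Int)

class SlowWorker extends Actor {
  import context.dispatcher

  def receive = {
    case SlowRequest(input) =>
      val requester: ActorRef = sender  // capture now, not inside the Future
      Future {
        input * 2  // stand-in for slow work such as deleting blocks
      }.onSuccess {
        case result => requester ! result
      }
  }
}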


@ -90,6 +90,20 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
def getFile(blockId: BlockId): File = getFile(blockId.name)
/** Check if disk block manager has a block. */
def containsBlock(blockId: BlockId): Boolean = {
getBlockLocation(blockId).file.exists()
}
/** List all the blocks currently stored on disk by the disk manager. */
def getAllBlocks(): Seq[BlockId] = {
// Get all the files inside the nested arrays of directories
subDirs.flatten.filter(_ != null).flatMap { dir =>
val files = dir.list()
if (files != null) files else Seq.empty
}.map(BlockId.apply)
}
/** Produces a unique block id and File suitable for intermediate results. */
def createTempBlock(): (TempBlockId, File) = {
var blockId = new TempBlockId(UUID.randomUUID())


@ -169,23 +169,43 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
throw new IllegalStateException("Failed to find shuffle block: " + id)
}
/** Remove all the blocks / files and metadata related to a particular shuffle. */
def removeShuffle(shuffleId: ShuffleId): Boolean = {
// Do not change the ordering here: shuffleStates should be removed only after
// the corresponding shuffle blocks have been removed
val cleaned = removeShuffleBlocks(shuffleId)
shuffleStates.remove(shuffleId)
cleaned
}
/** Remove all the blocks / files related to a particular shuffle. */
private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
shuffleStates.get(shuffleId) match {
case Some(state) =>
if (consolidateShuffleFiles) {
for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
file.delete()
}
} else {
for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
blockManager.diskBlockManager.getFile(blockId).delete()
}
}
logInfo("Deleted all files for shuffle " + shuffleId)
true
case None =>
logInfo("Could not find files for shuffle " + shuffleId + " for deleting")
false
}
}
private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
"merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
}
private def cleanup(cleanupTime: Long) {
shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => {
if (consolidateShuffleFiles) {
for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
file.delete()
}
} else {
for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
blockManager.diskBlockManager.getFile(blockId).delete()
}
}
})
shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
}
}
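For intuition, a hedged REPL-style sketch of the block files the non-consolidated branch deletes for a small shuffle (the numbers are made up):

val shuffleId = 0
val completedMapTasks = Seq(0, 1)  // two finished map tasks
val numBuckets = 3                 // three reduce partitions
val doomed = for (mapId <- completedMapTasks; reduceId <- 0 until numBuckets)
  yield ShuffleBlockId(shuffleId, mapId, reduceId).name
// doomed: shuffle_0_0_0, shuffle_0_0_1, shuffle_0_0_2, shuffle_0_1_0, shuffle_0_1_1, shuffle_0_1_2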


@ -22,7 +22,7 @@ import java.util.concurrent.ArrayBlockingQueue
import akka.actor._
import util.Random
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.KryoSerializer
@ -48,7 +48,7 @@ private[spark] object ThreadingTest {
val block = (1 to blockSize).map(_ => Random.nextInt())
val level = randomLevel()
val startTime = System.currentTimeMillis()
manager.put(blockId, block.iterator, level, true)
manager.put(blockId, block.iterator, level, tellMaster = true)
println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms")
queue.add((blockId, block))
}
@ -101,7 +101,7 @@ private[spark] object ThreadingTest {
conf)
val blockManager = new BlockManager(
"<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
new SecurityManager(conf))
new SecurityManager(conf), new MapOutputTrackerMaster(conf))
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)


@ -62,8 +62,8 @@ private[spark] class MetadataCleaner(
private[spark] object MetadataCleanerType extends Enumeration {
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER,
SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value
@ -78,15 +78,16 @@ private[spark] object MetadataCleaner {
conf.getInt("spark.cleaner.ttl", -1)
}
def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int =
{
conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString)
.toInt
def getDelaySeconds(
conf: SparkConf,
cleanerType: MetadataCleanerType.MetadataCleanerType): Int = {
conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt
}
def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType,
delay: Int)
{
def setDelaySeconds(
conf: SparkConf,
cleanerType: MetadataCleanerType.MetadataCleanerType,
delay: Int) {
conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString)
}


@ -17,48 +17,54 @@
package org.apache.spark.util
import java.util.Set
import java.util.Map.Entry
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConversions
import scala.collection.immutable
import scala.collection.mutable.Map
import scala.collection.{JavaConversions, mutable}
import org.apache.spark.Logging
private[spark] case class TimeStampedValue[V](value: V, timestamp: Long)
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
* timestamp along with each key-value pair. If specified, the timestamp of each pair can be
* updated every time it is accessed. Key-value pairs whose timestamp are older than a particular
* threshold time can then be removed using the clearOldValues method. This is intended to
* be a drop-in replacement of scala.collection.mutable.HashMap.
* @param updateTimeStampOnGet When enabled, the timestamp of a pair will be
* updated when it is accessed
*
* @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed
*/
class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
extends Map[A, B]() with Logging {
val internalMap = new ConcurrentHashMap[A, (B, Long)]()
private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
extends mutable.Map[A, B]() with Logging {
private val internalMap = new ConcurrentHashMap[A, TimeStampedValue[B]]()
def get(key: A): Option[B] = {
val value = internalMap.get(key)
if (value != null && updateTimeStampOnGet) {
internalMap.replace(key, value, (value._1, currentTime))
internalMap.replace(key, value, TimeStampedValue(value.value, currentTime))
}
Option(value).map(_._1)
Option(value).map(_.value)
}
def iterator: Iterator[(A, B)] = {
val jIterator = internalMap.entrySet().iterator()
JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1))
val jIterator = getEntrySet.iterator
JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value))
}
override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = {
def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet
override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
val newMap = new TimeStampedHashMap[A, B1]
newMap.internalMap.putAll(this.internalMap)
newMap.internalMap.put(kv._1, (kv._2, currentTime))
val oldInternalMap = this.internalMap.asInstanceOf[ConcurrentHashMap[A, TimeStampedValue[B1]]]
newMap.internalMap.putAll(oldInternalMap)
kv match { case (a, b) => newMap.internalMap.put(a, TimeStampedValue(b, currentTime)) }
newMap
}
override def - (key: A): Map[A, B] = {
override def - (key: A): mutable.Map[A, B] = {
val newMap = new TimeStampedHashMap[A, B]
newMap.internalMap.putAll(this.internalMap)
newMap.internalMap.remove(key)
@ -66,17 +72,10 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
}
override def += (kv: (A, B)): this.type = {
internalMap.put(kv._1, (kv._2, currentTime))
kv match { case (a, b) => internalMap.put(a, TimeStampedValue(b, currentTime)) }
this
}
// Should we return previous value directly or as Option ?
def putIfAbsent(key: A, value: B): Option[B] = {
val prev = internalMap.putIfAbsent(key, (value, currentTime))
if (prev != null) Some(prev._1) else None
}
override def -= (key: A): this.type = {
internalMap.remove(key)
this
@ -87,53 +86,65 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
}
override def apply(key: A): B = {
val value = internalMap.get(key)
if (value == null) throw new NoSuchElementException()
value._1
get(key).getOrElse { throw new NoSuchElementException() }
}
override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = {
JavaConversions.mapAsScalaConcurrentMap(internalMap)
.map { case (k, TimeStampedValue(v, t)) => (k, v) }
.filter(p)
}
override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]()
override def size: Int = internalMap.size
override def foreach[U](f: ((A, B)) => U) {
val iterator = internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val kv = (entry.getKey, entry.getValue._1)
val it = getEntrySet.iterator
while(it.hasNext) {
val entry = it.next()
val kv = (entry.getKey, entry.getValue.value)
f(kv)
}
}
def toMap: immutable.Map[A, B] = iterator.toMap
def putIfAbsent(key: A, value: B): Option[B] = {
val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime))
Option(prev).map(_.value)
}
def putAll(map: Map[A, B]) {
map.foreach { case (k, v) => update(k, v) }
}
def toMap: Map[A, B] = iterator.toMap
/**
* Removes old key-value pairs that have timestamp earlier than `threshTime`,
* calling the supplied function on each such entry before removing.
*/
def clearOldValues(threshTime: Long, f: (A, B) => Unit) {
val iterator = internalMap.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
if (entry.getValue._2 < threshTime) {
f(entry.getKey, entry.getValue._1)
val it = getEntrySet.iterator
while (it.hasNext) {
val entry = it.next()
if (entry.getValue.timestamp < threshTime) {
f(entry.getKey, entry.getValue.value)
logDebug("Removing key " + entry.getKey)
iterator.remove()
it.remove()
}
}
}
/**
* Removes old key-value pairs that have timestamp earlier than `threshTime`
*/
/** Removes old key-value pairs that have timestamp earlier than `threshTime`. */
def clearOldValues(threshTime: Long) {
clearOldValues(threshTime, (_, _) => ())
}
private def currentTime: Long = System.currentTimeMillis()
private def currentTime: Long = System.currentTimeMillis
// For testing
def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = {
Option(internalMap.get(key))
}
def getTimestamp(key: A): Option[Long] = {
getTimeStampedValue(key).map(_.timestamp)
}
}
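A hedged REPL-style sketch of the TTL semantics (the sleep only ensures the two timestamps differ):

val map = new TimeStampedHashMap[String, Int]()
map("a") = 1  // timestamped at insertion
Thread.sleep(10)
val cutoff = System.currentTimeMillis
map("b") = 2
map.clearOldValues(cutoff)  // drops "a" (older than cutoff), keeps "b"
assert(map.get("a").isEmpty)
assert(map.get("b") == Some(2))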


@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.lang.ref.WeakReference
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable
import org.apache.spark.Logging
/**
* A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped.
*
* If a value is garbage collected and its weak reference becomes null, get() behaves as if the
* key were not present. Such entries are removed from the map periodically (every N inserts), as
* their values are no longer strongly reachable. Further, key-value pairs whose timestamps are
* older than a particular threshold can be removed using the clearOldValues method.
*
* TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it
* to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap,
* so all operations on this HashMap are thread-safe.
*
* @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed.
*/
private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false)
extends mutable.Map[A, B]() with Logging {
import TimeStampedWeakValueHashMap._
private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet)
private val insertCount = new AtomicInteger(0)
/** Return a map consisting only of entries whose values are still strongly reachable. */
private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null }
def get(key: A): Option[B] = internalMap.get(key)
def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator
override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
val newMap = new TimeStampedWeakValueHashMap[A, B1]
val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]]
newMap.internalMap.putAll(oldMap.toMap)
newMap.internalMap += kv
newMap
}
override def - (key: A): mutable.Map[A, B] = {
val newMap = new TimeStampedWeakValueHashMap[A, B]
newMap.internalMap.putAll(nonNullReferenceMap.toMap)
newMap.internalMap -= key
newMap
}
override def += (kv: (A, B)): this.type = {
internalMap += kv
if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) {
clearNullValues()
}
this
}
override def -= (key: A): this.type = {
internalMap -= key
this
}
override def update(key: A, value: B) = this += ((key, value))
override def apply(key: A): B = internalMap.apply(key)
override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p)
override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]()
override def size: Int = internalMap.size
override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f)
def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value)
def toMap: Map[A, B] = iterator.toMap
/** Remove old key-value pairs with timestamps earlier than `threshTime`. */
def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime)
/** Remove entries with values that are no longer strongly reachable. */
def clearNullValues() {
val it = internalMap.getEntrySet.iterator
while (it.hasNext) {
val entry = it.next()
if (entry.getValue.value.get == null) {
logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.")
it.remove()
}
}
}
// For testing
def getTimestamp(key: A): Option[Long] = {
internalMap.getTimeStampedValue(key).map(_.timestamp)
}
def getReference(key: A): Option[WeakReference[B]] = {
internalMap.getTimeStampedValue(key).map(_.value)
}
}
/**
* Helper methods for converting to and from WeakReferences.
*/
private object TimeStampedWeakValueHashMap {
// Number of inserts after which entries with null references are removed
val CLEAR_NULL_VALUES_INTERVAL = 100
/* Implicit conversion methods to WeakReferences. */
implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v)
implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = {
kv match { case (k, v) => (k, toWeakReference(v)) }
}
implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = {
(kv: (K, WeakReference[V])) => p(kv)
}
/* Implicit conversion methods from WeakReferences. */
implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get
implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = {
v match {
case Some(ref) => Option(fromWeakReference(ref))
case None => None
}
}
implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = {
kv match { case (k, v) => (k, fromWeakReference(v)) }
}
implicit def fromWeakReferenceIterator[K, V](
it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = {
it.map(fromWeakReferenceTuple)
}
implicit def fromWeakReferenceMap[K, V](
map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = {
mutable.Map(map.mapValues(fromWeakReference).toSeq: _*)
}
}
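And a hedged sketch of the weak-value semantics; garbage collection timing is nondeterministic, so this illustrates the contract rather than a reliable test:

val map = new TimeStampedWeakValueHashMap[Int, Array[Byte]]()
map(1) = new Array[Byte](1024)  // the map holds only a WeakReference to the value
// Once nothing else strongly references the array and a GC cycle runs,
// get(1) behaves as if the key were absent ...
System.gc()
// ... and the stale entry itself is purged lazily, the next time clearNullValues()
// runs (every CLEAR_NULL_VALUES_INTERVAL = 100 inserts).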


@ -499,10 +499,10 @@ private[spark] object Utils extends Logging {
private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
def parseHostPort(hostPort: String): (String, Int) = {
{
// Check cache first.
val cached = hostPortParseResults.get(hostPort)
if (cached != null) return cached
// Check cache first.
val cached = hostPortParseResults.get(hostPort)
if (cached != null) {
return cached
}
val indx: Int = hostPort.lastIndexOf(':')


@ -56,7 +56,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = conf, securityManager = securityManagerBad)
val slaveTracker = new MapOutputTracker(conf)
val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
@ -93,7 +93,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = badconf, securityManager = securityManagerBad)
val slaveTracker = new MapOutputTracker(conf)
val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
@ -147,7 +147,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = goodconf, securityManager = securityManagerGood)
val slaveTracker = new MapOutputTracker(conf)
val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
@ -200,7 +200,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = badconf, securityManager = securityManagerBad)
val slaveTracker = new MapOutputTracker(conf)
val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)


@ -19,68 +19,297 @@ package org.apache.spark
import org.scalatest.FunSuite
import org.apache.spark.storage._
import org.apache.spark.broadcast.{Broadcast, HttpBroadcast}
import org.apache.spark.storage.BroadcastBlockId
class BroadcastSuite extends FunSuite with LocalSparkContext {
override def afterEach() {
super.afterEach()
System.clearProperty("spark.broadcast.factory")
}
private val httpConf = broadcastConf("HttpBroadcastFactory")
private val torrentConf = broadcastConf("TorrentBroadcastFactory")
test("Using HttpBroadcast locally") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === Set((1, 10), (2, 10)))
sc = new SparkContext("local", "test", httpConf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === Set((1, 10), (2, 10)))
}
test("Accessing HttpBroadcast variables from multiple threads") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
sc = new SparkContext("local[10]", "test", httpConf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
}
test("Accessing HttpBroadcast variables in a local cluster") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
val numSlaves = 4
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}
test("Using TorrentBroadcast locally") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === Set((1, 10), (2, 10)))
sc = new SparkContext("local", "test", torrentConf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === Set((1, 10), (2, 10)))
}
test("Accessing TorrentBroadcast variables from multiple threads") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
sc = new SparkContext("local[10]", "test", torrentConf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
}
test("Accessing TorrentBroadcast variables in a local cluster") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
val numSlaves = 4
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}
test("Unpersisting HttpBroadcast on executors only in local mode") {
testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
}
test("Unpersisting HttpBroadcast on executors and driver in local mode") {
testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true)
}
test("Unpersisting HttpBroadcast on executors only in distributed mode") {
testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false)
}
test("Unpersisting HttpBroadcast on executors and driver in distributed mode") {
testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true)
}
test("Unpersisting TorrentBroadcast on executors only in local mode") {
testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false)
}
test("Unpersisting TorrentBroadcast on executors and driver in local mode") {
testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = true)
}
test("Unpersisting TorrentBroadcast on executors only in distributed mode") {
testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = false)
}
test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") {
testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true)
}
/**
* Verify the persistence of state associated with an HttpBroadcast in either local mode or
* local-cluster mode (when distributed = true).
*
* This test creates a broadcast variable, uses it on all executors, and then unpersists it.
* In between each step, this test verifies that the broadcast blocks and the broadcast file
* are present only on the expected nodes.
*/
private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
val numSlaves = if (distributed) 2 else 0
def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
// Verify that the broadcast file is created, and blocks are persisted only on the driver
def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
assert(blockIds.size === 1)
val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
assert(statuses.size === 1)
statuses.head match { case (bm, status) =>
assert(bm.executorId === "<driver>", "Block should only be on the driver")
assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
assert(status.memSize > 0, "Block should be in memory store on the driver")
assert(status.diskSize === 0, "Block should not be in disk store on the driver")
}
if (distributed) {
// this file is only generated in distributed mode
assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
}
}
// Verify that blocks are persisted in both the executors and the driver
def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
assert(blockIds.size === 1)
val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
assert(statuses.size === numSlaves + 1)
statuses.foreach { case (_, status) =>
assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
assert(status.memSize > 0, "Block should be in memory store")
assert(status.diskSize === 0, "Block should not be in disk store")
}
}
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true. In the latter case, also verify that the broadcast file is deleted on the driver.
def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
assert(blockIds.size === 1)
val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
val expectedNumBlocks = if (removeFromDriver) 0 else 1
val possiblyNot = if (removeFromDriver) "" else " not"
assert(statuses.size === expectedNumBlocks,
"Block should%s be unpersisted on the driver".format(possiblyNot))
if (distributed && removeFromDriver) {
// this file is only generated in distributed mode
assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
"Broadcast file should%s be deleted".format(possiblyNot))
}
}
testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
afterUsingBroadcast, afterUnpersist, removeFromDriver)
}
/**
* Verify the persistence of state associated with a TorrentBroadcast in either local mode or
* local-cluster mode (when distributed = true).
*
* This test creates a broadcast variable, uses it on all executors, and then unpersists it.
* In between each step, this test verifies that the broadcast blocks are present only on the
* expected nodes.
*/
private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
val numSlaves = if (distributed) 2 else 0
def getBlockIds(id: Long) = {
val broadcastBlockId = BroadcastBlockId(id)
val metaBlockId = BroadcastBlockId(id, "meta")
// Assume broadcast value is small enough to fit into 1 piece
val pieceBlockId = BroadcastBlockId(id, "piece0")
if (distributed) {
// the metadata and piece blocks are generated only in distributed mode
Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
} else {
Seq[BroadcastBlockId](broadcastBlockId)
}
}
// Verify that blocks are persisted only on the driver
def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
blockIds.foreach { blockId =>
val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
assert(statuses.size === 1)
statuses.head match { case (bm, status) =>
assert(bm.executorId === "<driver>", "Block should only be on the driver")
assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
assert(status.memSize > 0, "Block should be in memory store on the driver")
assert(status.diskSize === 0, "Block should not be in disk store on the driver")
}
}
}
// Verify that blocks are persisted in both the executors and the driver
def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
blockIds.foreach { blockId =>
val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
if (blockId.field == "meta") {
// Metadata is only on the driver
assert(statuses.size === 1)
statuses.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
} else {
// Other blocks are on both the executors and the driver
assert(statuses.size === numSlaves + 1,
blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
statuses.foreach { case (_, status) =>
assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
assert(status.memSize > 0, "Block should be in memory store")
assert(status.diskSize === 0, "Block should not be in disk store")
}
}
}
}
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true.
def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
val expectedNumBlocks = if (removeFromDriver) 0 else 1
val possiblyNot = if (removeFromDriver) "" else " not"
blockIds.foreach { blockId =>
val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === expectedNumBlocks,
"Block should%s be unpersisted on the driver".format(possiblyNot))
}
}
testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation,
afterUsingBroadcast, afterUnpersist, removeFromDriver)
}
/**
* This test runs in 4 steps:
*
* 1) Create broadcast variable, and verify that all state is persisted on the driver.
* 2) Use the broadcast variable on all executors, and verify that all state is persisted
* on both the driver and the executors.
* 3) Unpersist the broadcast, and verify that all state is removed where it should be.
* 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable.
*/
private def testUnpersistBroadcast(
distributed: Boolean,
numSlaves: Int, // used only when distributed = true
broadcastConf: SparkConf,
getBlockIds: Long => Seq[BroadcastBlockId],
afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
removeFromDriver: Boolean) {
sc = if (distributed) {
new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
} else {
new SparkContext("local", "test", broadcastConf)
}
val blockManagerMaster = sc.env.blockManager.master
val list = List[Int](1, 2, 3, 4)
// Create broadcast variable
val broadcast = sc.broadcast(list)
val blocks = getBlockIds(broadcast.id)
afterCreation(blocks, blockManagerMaster)
// Use broadcast variable on all executors
val partitions = 10
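// Use more partitions than executors so that every executor runs at least one task and fetches the broadcast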
assert(partitions > numSlaves)
val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
afterUsingBroadcast(blocks, blockManagerMaster)
// Unpersist broadcast
if (removeFromDriver) {
broadcast.destroy(blocking = true)
} else {
broadcast.unpersist(blocking = true)
}
afterUnpersist(blocks, blockManagerMaster)
// If the broadcast is removed from driver, all subsequent uses of the broadcast variable
// should throw SparkExceptions. Otherwise, the result should be the same as before.
if (removeFromDriver) {
// Using this variable on the executors crashes them, which hangs the test.
// Instead, crash the driver by directly accessing the broadcast value.
intercept[SparkException] { broadcast.value }
intercept[SparkException] { broadcast.unpersist() }
intercept[SparkException] { broadcast.destroy(blocking = true) }
} else {
val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
}
}
/** Helper method to create a SparkConf that uses the given broadcast factory. */
private def broadcastConf(factoryName: String): SparkConf = {
val conf = new SparkConf
conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName))
conf
}
}


@ -0,0 +1,415 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.lang.ref.WeakReference
import scala.collection.mutable.{HashSet, SynchronizedSet}
import scala.util.Random
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.concurrent.Eventually
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId}
class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
implicit val defaultTimeout = timeout(10000 millis)
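// Local context with blocking reference-tracking cleanup, so explicit and GC-triggered cleanups complete before assertions run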
val conf = new SparkConf()
.setMaster("local[2]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
before {
sc = new SparkContext(conf)
}
after {
if (sc != null) {
sc.stop()
sc = null
}
}
test("cleanup RDD") {
val rdd = newRDD.persist()
val collected = rdd.collect().toList
val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
// Explicit cleanup
cleaner.doCleanupRDD(rdd.id, blocking = true)
tester.assertCleanup()
// Verify that RDDs can be re-executed after cleaning up
assert(rdd.collect().toList === collected)
}
test("cleanup shuffle") {
val (rdd, shuffleDeps) = newRDDWithShuffleDependencies
val collected = rdd.collect().toList
val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
// Explicit cleanup
shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true))
tester.assertCleanup()
// Verify that shuffles can be re-executed after cleaning up
assert(rdd.collect().toList === collected)
}
test("cleanup broadcast") {
val broadcast = newBroadcast
val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
// Explicit cleanup
cleaner.doCleanupBroadcast(broadcast.id, blocking = true)
tester.assertCleanup()
}
test("automatically cleanup RDD") {
var rdd = newRDD.persist()
rdd.count()
// Test that GC does not cause RDD cleanup due to a strong reference
val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
runGC()
intercept[Exception] {
preGCTester.assertCleanup()(timeout(1000 millis))
}
// Test that GC causes RDD cleanup after dereferencing the RDD
val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
rdd = null // Make RDD out of scope
runGC()
postGCTester.assertCleanup()
}
test("automatically cleanup shuffle") {
var rdd = newShuffleRDD
rdd.count()
// Test that GC does not cause shuffle cleanup due to a strong reference
val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
runGC()
intercept[Exception] {
preGCTester.assertCleanup()(timeout(1000 millis))
}
// Test that GC causes shuffle cleanup after dereferencing the RDD
val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope
runGC()
postGCTester.assertCleanup()
}
test("automatically cleanup broadcast") {
var broadcast = newBroadcast
// Test that GC does not cause broadcast cleanup due to a strong reference
val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
runGC()
intercept[Exception] {
preGCTester.assertCleanup()(timeout(1000 millis))
}
// Test that GC causes broadcast cleanup after dereferencing the broadcast variable
val postGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
broadcast = null // Make broadcast variable out of scope
runGC()
postGCTester.assertCleanup()
}
test("automatically cleanup RDD + shuffle + broadcast") {
val numRdds = 100
val numBroadcasts = 4 // Broadcasts are more costly
val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
val rddIds = sc.persistentRdds.keys.toSeq
val shuffleIds = 0 until sc.newShuffleId
val broadcastIds = 0L until numBroadcasts
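// Everything created above is still strongly referenced by the buffers, so GC should not trigger any cleanup yet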
val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
runGC()
intercept[Exception] {
preGCTester.assertCleanup()(timeout(1000 millis))
}
// Test that GC triggers the cleanup of all variables after dereferencing them
val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
broadcastBuffer.clear()
rddBuffer.clear()
runGC()
postGCTester.assertCleanup()
}
test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
sc.stop()
val conf2 = new SparkConf()
.setMaster("local-cluster[2, 1, 512]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
sc = new SparkContext(conf2)
val numRdds = 10
val numBroadcasts = 4 // Broadcasts are more costly
val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
val rddIds = sc.persistentRdds.keys.toSeq
val shuffleIds = 0 until sc.newShuffleId
val broadcastIds = 0L until numBroadcasts
val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
runGC()
intercept[Exception] {
preGCTester.assertCleanup()(timeout(1000 millis))
}
// Test that GC triggers the cleanup of all variables after dereferencing them
val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
broadcastBuffer.clear()
rddBuffer.clear()
runGC()
postGCTester.assertCleanup()
}
//------ Helper functions ------
def newRDD = sc.makeRDD(1 to 10)
def newPairRDD = newRDD.map(_ -> 1)
def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
def newBroadcast = sc.broadcast(1 to 100)
def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = {
def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
getAllDependencies(dep.rdd)
}
}
val rdd = newShuffleRDD
// Get all the shuffle dependencies
val shuffleDeps = getAllDependencies(rdd)
.filter(_.isInstanceOf[ShuffleDependency[_, _]])
.map(_.asInstanceOf[ShuffleDependency[_, _]])
(rdd, shuffleDeps)
}
def randomRdd = {
val rdd: RDD[_] = Random.nextInt(3) match {
case 0 => newRDD
case 1 => newShuffleRDD
case 2 => newPairRDD.join(newPairRDD)
}
if (Random.nextBoolean()) rdd.persist()
rdd.count()
rdd
}
def randomBroadcast = {
sc.broadcast(Random.nextInt(Int.MaxValue))
}
/** Run GC and make sure it actually has run */
def runGC() {
val weakRef = new WeakReference(new Object())
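// The temporary object above is only weakly reachable; once weakRef is cleared, at least one GC cycle has completed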
val startTime = System.currentTimeMillis
System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
// Wait until a weak reference object has been GCed
while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
System.gc()
Thread.sleep(200)
}
}
def cleaner = sc.cleaner.get
}
/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */
class CleanerTester(
sc: SparkContext,
rddIds: Seq[Int] = Seq.empty,
shuffleIds: Seq[Int] = Seq.empty,
broadcastIds: Seq[Long] = Seq.empty)
extends Logging {
val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds
val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds
val toBeCleanedBroadcastIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds
val isDistributed = !sc.isLocal
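// Listener attached to the ContextCleaner below; it records completed cleanups by removing ids from the to-be-cleaned sets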
val cleanerListener = new CleanerListener {
def rddCleaned(rddId: Int): Unit = {
toBeCleanedRDDIds -= rddId
logInfo("RDD "+ rddId + " cleaned")
}
def shuffleCleaned(shuffleId: Int): Unit = {
toBeCleanedShuffleIds -= shuffleId
logInfo("Shuffle " + shuffleId + " cleaned")
}
def broadcastCleaned(broadcastId: Long): Unit = {
toBeCleanedBroadcastIds -= broadcastId
logInfo("Broadcast " + broadcastId + " cleaned")
}
}
val MAX_VALIDATION_ATTEMPTS = 10
val VALIDATION_ATTEMPT_INTERVAL = 100
logInfo("Attempting to validate before cleanup:\n" + uncleanedResourcesToString)
preCleanupValidate()
sc.cleaner.get.attachListener(cleanerListener)
/** Assert that all of the tracked RDDs, shuffles and broadcasts have been cleaned up */
def assertCleanup()(implicit waitTimeout: Eventually.Timeout) {
try {
eventually(waitTimeout, interval(100 millis)) {
assert(isAllCleanedUp)
}
postCleanupValidate()
} finally {
logInfo("Resources left from cleaning up:\n" + uncleanedResourcesToString)
}
}
/** Verify that RDDs, shuffles, etc. occupy resources */
private def preCleanupValidate() {
assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup")
// Verify the RDDs have been persisted and blocks are present
rddIds.foreach { rddId =>
assert(
sc.persistentRdds.contains(rddId),
"RDD " + rddId + " have not been persisted, cannot start cleaner test"
)
assert(
!getRDDBlocks(rddId).isEmpty,
"Blocks of RDD " + rddId + " cannot be found in block manager, " +
"cannot start cleaner test"
)
}
// Verify the shuffle ids are registered and blocks are present
shuffleIds.foreach { shuffleId =>
assert(
mapOutputTrackerMaster.containsShuffle(shuffleId),
"Shuffle " + shuffleId + " have not been registered, cannot start cleaner test"
)
assert(
!getShuffleBlocks(shuffleId).isEmpty,
"Blocks of shuffle " + shuffleId + " cannot be found in block manager, " +
"cannot start cleaner test"
)
}
// Verify that the broadcast blocks are present
broadcastIds.foreach { broadcastId =>
assert(
!getBroadcastBlocks(broadcastId).isEmpty,
"Blocks of broadcast " + broadcastId + "cannot be found in block manager, " +
"cannot start cleaner test"
)
}
}
/**
* Verify that RDDs, shuffles, etc. do not occupy resources. Validation is retried multiple
* times because there is no guarantee on how long it will take to clean up the resources.
*/
private def postCleanupValidate() {
// Verify that the RDDs have been unpersisted and their blocks removed from the block manager
rddIds.foreach { rddId =>
assert(
!sc.persistentRdds.contains(rddId),
"RDD " + rddId + " was not cleared from sc.persistentRdds"
)
assert(
getRDDBlocks(rddId).isEmpty,
"Blocks of RDD " + rddId + " were not cleared from block manager"
)
}
// Verify that the shuffles have been deregistered and their blocks removed from the block manager
shuffleIds.foreach { shuffleId =>
assert(
!mapOutputTrackerMaster.containsShuffle(shuffleId),
"Shuffle " + shuffleId + " was not deregistered from map output tracker"
)
assert(
getShuffleBlocks(shuffleId).isEmpty,
"Blocks of shuffle " + shuffleId + " were not cleared from block manager"
)
}
// Verify that the broadcast blocks have been removed from the block manager
broadcastIds.foreach { broadcastId =>
assert(
getBroadcastBlocks(broadcastId).isEmpty,
"Blocks of broadcast " + broadcastId + " were not cleared from block manager"
)
}
}
private def uncleanedResourcesToString = {
s"""
|\tRDDs = ${toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]")}
|\tShuffles = ${toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]")}
|\tBroadcasts = ${toBeCleanedBroadcastIds.toSeq.sorted.mkString("[", ", ", "]")}
""".stripMargin
}
private def isAllCleanedUp =
toBeCleanedRDDIds.isEmpty &&
toBeCleanedShuffleIds.isEmpty &&
toBeCleanedBroadcastIds.isEmpty
private def getRDDBlocks(rddId: Int): Seq[BlockId] = {
blockManager.master.getMatchingBlockIds( _ match {
case RDDBlockId(`rddId`, _) => true
case _ => false
}, askSlaves = true)
}
private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = {
blockManager.master.getMatchingBlockIds( _ match {
case ShuffleBlockId(`shuffleId`, _, _) => true
case _ => false
}, askSlaves = true)
}
private def getBroadcastBlocks(broadcastId: Long): Seq[BlockId] = {
blockManager.master.getMatchingBlockIds( _ match {
case BroadcastBlockId(`broadcastId`, _) => true
case _ => false
}, askSlaves = true)
}
private def blockManager = sc.env.blockManager
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
}


@ -57,12 +57,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
tracker.stop()
}
test("master register and fetch") {
test("master register shuffle and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerActor =
actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.registerShuffle(10, 2)
assert(tracker.containsShuffle(10))
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
@ -77,7 +78,25 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
tracker.stop()
}
test("master register and unregister and fetch") {
test("master register and unregister shuffle") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize10000)))
tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000)))
assert(tracker.containsShuffle(10))
assert(tracker.getServerStatuses(10, 0).nonEmpty)
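// Unregistering the shuffle should remove both its registration and its map output statuses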
tracker.unregisterShuffle(10)
assert(!tracker.containsShuffle(10))
assert(tracker.getServerStatuses(10, 0).isEmpty)
}
test("master register shuffle and unregister map output and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerActor =
@ -114,7 +133,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
securityManager = new SecurityManager(conf))
val slaveTracker = new MapOutputTracker(conf)
val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)


@ -28,7 +28,7 @@ import org.scalatest.concurrent.Timeouts._
import org.scalatest.matchers.ShouldMatchers._
import org.scalatest.time.SpanSugar._
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils}
@ -42,6 +42,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
var oldArch: String = null
conf.set("spark.authenticate", "false")
val securityMgr = new SecurityManager(conf)
val mapOutputTracker = new MapOutputTrackerMaster(conf)
// Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
conf.set("spark.kryoserializer.buffer.mb", "1")
@ -130,7 +131,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 1 manager interaction") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -160,9 +162,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 2 managers interaction") {
store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf,
securityMgr)
securityMgr, mapOutputTracker)
val peers = master.getPeers(store.blockManagerId, 1)
assert(peers.size === 1, "master did not return the other manager as a peer")
@ -177,7 +180,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing block") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -225,7 +229,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing rdd") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -257,9 +262,82 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
master.getLocations(rdd(0, 1)) should have size 0
}
test("removing broadcast") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
val driverStore = store
val executorStore = new BlockManager("executor", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
val a4 = new Array[Byte](400)
val broadcast0BlockId = BroadcastBlockId(0)
val broadcast1BlockId = BroadcastBlockId(1)
val broadcast2BlockId = BroadcastBlockId(2)
val broadcast2BlockId2 = BroadcastBlockId(2, "_")
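// A second block belonging to broadcast 2 (e.g. a piece block); removing broadcast 2 should delete this one as well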
// insert broadcast blocks in both the stores
Seq(driverStore, executorStore).foreach { case s =>
s.putSingle(broadcast0BlockId, a1, StorageLevel.DISK_ONLY)
s.putSingle(broadcast1BlockId, a2, StorageLevel.DISK_ONLY)
s.putSingle(broadcast2BlockId, a3, StorageLevel.DISK_ONLY)
s.putSingle(broadcast2BlockId2, a4, StorageLevel.DISK_ONLY)
}
// verify whether the blocks exist in both the stores
Seq(driverStore, executorStore).foreach { case s =>
s.getLocal(broadcast0BlockId) should not be (None)
s.getLocal(broadcast1BlockId) should not be (None)
s.getLocal(broadcast2BlockId) should not be (None)
s.getLocal(broadcast2BlockId2) should not be (None)
}
// remove broadcast 0 block only from executors
master.removeBroadcast(0, removeFromMaster = false, blocking = true)
// only broadcast 0 block should be removed from the executor store
executorStore.getLocal(broadcast0BlockId) should be (None)
executorStore.getLocal(broadcast1BlockId) should not be (None)
executorStore.getLocal(broadcast2BlockId) should not be (None)
// nothing should be removed from the driver store
driverStore.getLocal(broadcast0BlockId) should not be (None)
driverStore.getLocal(broadcast1BlockId) should not be (None)
driverStore.getLocal(broadcast2BlockId) should not be (None)
// remove broadcast 0 block from the driver as well
master.removeBroadcast(0, removeFromMaster = true, blocking = true)
driverStore.getLocal(broadcast0BlockId) should be (None)
driverStore.getLocal(broadcast1BlockId) should not be (None)
// remove broadcast 1 block from both the stores asynchronously
// and verify all broadcast 1 blocks have been removed
master.removeBroadcast(1, removeFromMaster = true, blocking = false)
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
driverStore.getLocal(broadcast1BlockId) should be (None)
executorStore.getLocal(broadcast1BlockId) should be (None)
}
// remove broadcast 2 from both the stores asynchronously
// and verify all broadcast 2 blocks have been removed
master.removeBroadcast(2, removeFromMaster = true, blocking = false)
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
driverStore.getLocal(broadcast2BlockId) should be (None)
driverStore.getLocal(broadcast2BlockId2) should be (None)
executorStore.getLocal(broadcast2BlockId) should be (None)
executorStore.getLocal(broadcast2BlockId2) should be (None)
}
executorStore.stop()
driverStore.stop()
store = null
}
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
@ -275,7 +353,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("reregistration on block update") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
@ -294,7 +373,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration doesn't dead lock") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = List(new Array[Byte](400))
@ -331,7 +411,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -350,7 +431,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage with serialization") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -369,7 +451,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of same RDD") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -388,7 +471,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of multiple RDDs") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
@ -414,7 +498,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
// TODO Make the spark.test.tachyon.enable true after using tachyon 0.5.0 testing jar.
val tachyonUnitTestEnabled = conf.getBoolean("spark.test.tachyon.enable", false)
if (tachyonUnitTestEnabled) {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -430,7 +515,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("on-disk storage") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -443,7 +529,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -458,7 +545,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with getLocalBytes") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -473,7 +561,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -488,7 +577,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization and getLocalBytes") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -503,7 +593,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@ -525,7 +616,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU with streams") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@ -549,7 +641,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels and streams") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@ -595,7 +688,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("overly large block") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 500, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 500, conf,
securityMgr, mapOutputTracker)
store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.getSingle("a1") === None, "a1 was in store")
store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK)
@ -606,7 +700,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block compression") {
try {
conf.set("spark.shuffle.compress", "true")
store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100,
"shuffle_0_0_0 was not compressed")
@ -614,7 +709,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.shuffle.compress", "false")
store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000,
"shuffle_0_0_0 was compressed")
@ -622,7 +718,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.broadcast.compress", "true")
store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100,
"broadcast_0 was not compressed")
@ -630,28 +727,32 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.broadcast.compress", "false")
store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
conf.set("spark.rdd.compress", "true")
store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
conf.set("spark.rdd.compress", "false")
store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
// Check that any other block types are also kept uncompressed
store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr)
store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf,
securityMgr, mapOutputTracker)
store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed")
store.stop()
@ -666,7 +767,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf,
securityMgr)
securityMgr, mapOutputTracker)
// The put should fail since a1 is not serializable.
class UnserializableClass
@ -682,7 +783,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("updated block statuses") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val list = List.fill(2)(new Array[Byte](200))
val bigList = List.fill(8)(new Array[Byte](200))
@ -735,8 +837,83 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(!store.get("list5").isDefined, "list5 was in store")
}
test("query block statuses") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val list = List.fill(2)(new Array[Byte](200))
// Tell master. By LRU, only list2 and list3 remain.
store.put("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
store.put("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
// getLocations and getBlockStatus should yield the same locations
assert(store.master.getLocations("list1").size === 0)
assert(store.master.getLocations("list2").size === 1)
assert(store.master.getLocations("list3").size === 1)
assert(store.master.getBlockStatus("list1", askSlaves = false).size === 0)
assert(store.master.getBlockStatus("list2", askSlaves = false).size === 1)
assert(store.master.getBlockStatus("list3", askSlaves = false).size === 1)
assert(store.master.getBlockStatus("list1", askSlaves = true).size === 0)
assert(store.master.getBlockStatus("list2", askSlaves = true).size === 1)
assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1)
// This time don't tell master and see what happens. By LRU, only list5 and list6 remain.
store.put("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false)
store.put("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
store.put("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false)
// getLocations should return nothing because the master is not informed
// getBlockStatus without asking slaves should have the same result
// getBlockStatus with asking slaves, however, should return the actual block statuses
assert(store.master.getLocations("list4").size === 0)
assert(store.master.getLocations("list5").size === 0)
assert(store.master.getLocations("list6").size === 0)
assert(store.master.getBlockStatus("list4", askSlaves = false).size === 0)
assert(store.master.getBlockStatus("list5", askSlaves = false).size === 0)
assert(store.master.getBlockStatus("list6", askSlaves = false).size === 0)
assert(store.master.getBlockStatus("list4", askSlaves = true).size === 0)
assert(store.master.getBlockStatus("list5", askSlaves = true).size === 1)
assert(store.master.getBlockStatus("list6", askSlaves = true).size === 1)
}
test("get matching blocks") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
val list = List.fill(2)(new Array[Byte](10))
// insert some blocks
store.put("list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
store.put("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
// the master knows about all three blocks, so matching works without asking the slaves
assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size === 3)
assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1)
// insert some more blocks
store.put("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
store.put("newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
store.put("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
// only newlist1 was reported to the master; asking the slaves finds all three
assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1)
assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3)
val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0))
blockIds.foreach { blockId =>
store.put(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
}
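// Query the master for all blocks of RDD 1; with askSlaves = true this also checks the executors' block managers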
val matchedBlockIds = store.master.getMatchingBlockIds(_ match {
case RDDBlockId(1, _) => true
case _ => false
}, askSlaves = true)
assert(matchedBlockIds.toSet === Set(RDDBlockId(1, 0), RDDBlockId(1, 1)))
}
test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
// Access rdd_1_0 to ensure it's not least recently used.


@ -59,8 +59,16 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach {
val newFile = diskBlockManager.getFile(blockId)
writeToFile(newFile, 10)
assertSegmentEquals(blockId, blockId.name, 0, 10)
assert(diskBlockManager.containsBlock(blockId))
newFile.delete()
assert(!diskBlockManager.containsBlock(blockId))
}
test("enumerating blocks") {
val ids = (1 to 100).map(i => TestBlockId("test_" + i))
val files = ids.map(id => diskBlockManager.getFile(id))
files.foreach(file => writeToFile(file, 10))
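// getAllBlocks should enumerate exactly the block ids whose files were written above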
assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
}
test("block appending") {


@ -108,8 +108,7 @@ class JsonProtocolSuite extends FunSuite {
// BlockId
testBlockId(RDDBlockId(1, 2))
testBlockId(ShuffleBlockId(1, 2, 3))
testBlockId(BroadcastBlockId(1L))
testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark"))
testBlockId(BroadcastBlockId(1L, "insert_words_of_wisdom_here"))
testBlockId(TaskResultBlockId(1L))
testBlockId(StreamBlockId(1, 2L))
}
@ -555,4 +554,4 @@ class JsonProtocolSuite extends FunSuite {
{"Event":"SparkListenerUnpersistRDD","RDD ID":12345}
"""
}
}


@ -0,0 +1,264 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.lang.ref.WeakReference
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import org.scalatest.FunSuite
class TimeStampedHashMapSuite extends FunSuite {
// Test the testMap function - a Scala HashMap should obviously pass
testMap(new mutable.HashMap[String, String]())
// Test TimeStampedHashMap basic functionality
testMap(new TimeStampedHashMap[String, String]())
testMapThreadSafety(new TimeStampedHashMap[String, String]())
// Test TimeStampedWeakValueHashMap basic functionality
testMap(new TimeStampedWeakValueHashMap[String, String]())
testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]())
test("TimeStampedHashMap - clearing by timestamp") {
// clearing by insertion time
val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false)
map("k1") = "v1"
assert(map("k1") === "v1")
Thread.sleep(10)
val threshTime = System.currentTimeMillis
assert(map.getTimestamp("k1").isDefined)
assert(map.getTimestamp("k1").get < threshTime)
map.clearOldValues(threshTime)
assert(map.get("k1") === None)
// clearing by modification time
val map1 = new TimeStampedHashMap[String, String](updateTimeStampOnGet = true)
map1("k1") = "v1"
map1("k2") = "v2"
assert(map1("k1") === "v1")
Thread.sleep(10)
val threshTime1 = System.currentTimeMillis
Thread.sleep(10)
assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime
assert(map1.getTimestamp("k1").isDefined)
assert(map1.getTimestamp("k1").get < threshTime1)
assert(map1.getTimestamp("k2").isDefined)
assert(map1.getTimestamp("k2").get >= threshTime1)
map1.clearOldValues(threshTime1) // should only clear k1
assert(map1.get("k1") === None)
assert(map1.get("k2").isDefined)
}
test("TimeStampedWeakValueHashMap - clearing by timestamp") {
// clearing by insertion time
val map = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = false)
map("k1") = "v1"
assert(map("k1") === "v1")
Thread.sleep(10)
val threshTime = System.currentTimeMillis
assert(map.getTimestamp("k1").isDefined)
assert(map.getTimestamp("k1").get < threshTime)
map.clearOldValues(threshTime)
assert(map.get("k1") === None)
// clearing by modification time
val map1 = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = true)
map1("k1") = "v1"
map1("k2") = "v2"
assert(map1("k1") === "v1")
Thread.sleep(10)
val threshTime1 = System.currentTimeMillis
Thread.sleep(10)
assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime
assert(map1.getTimestamp("k1").isDefined)
assert(map1.getTimestamp("k1").get < threshTime1)
assert(map1.getTimestamp("k2").isDefined)
assert(map1.getTimestamp("k2").get >= threshTime1)
map1.clearOldValues(threshTime1) // should only clear k1
assert(map1.get("k1") === None)
assert(map1.get("k2").isDefined)
}
test("TimeStampedWeakValueHashMap - clearing weak references") {
var strongRef = new Object
val weakRef = new WeakReference(strongRef)
val map = new TimeStampedWeakValueHashMap[String, Object]
map("k1") = strongRef
map("k2") = "v2"
map("k3") = "v3"
assert(map("k1") === strongRef)
// clear strong reference to "k1"
strongRef = null
val startTime = System.currentTimeMillis
System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
System.runFinalization() // Make a best effort to call finalizer on all cleaned objects.
while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
System.gc()
System.runFinalization()
Thread.sleep(100)
}
assert(map.getReference("k1").isDefined)
val ref = map.getReference("k1").get
assert(ref.get === null)
assert(map.get("k1") === None)
// operations should only display non-null entries
assert(map.iterator.forall { case (k, v) => k != "k1" })
assert(map.filter { case (k, v) => k != "k2" }.size === 1)
assert(map.filter { case (k, v) => k != "k2" }.head._1 === "k3")
assert(map.toMap.size === 2)
assert(map.toMap.forall { case (k, v) => k != "k1" })
val buffer = new ArrayBuffer[String]
map.foreach { case (k, v) => buffer += v.toString }
assert(buffer.size === 2)
assert(buffer.forall(_ != "k1"))
val plusMap = map + (("k4", "v4"))
assert(plusMap.size === 3)
assert(plusMap.forall { case (k, v) => k != "k1" })
val minusMap = map - "k2"
assert(minusMap.size === 1)
assert(minusMap.head._1 == "k3")
// clear null values - should only clear k1
map.clearNullValues()
assert(map.getReference("k1") === None)
assert(map.get("k1") === None)
assert(map.get("k2").isDefined)
assert(map.get("k2").get === "v2")
}
/** Test basic operations of a Scala mutable Map. */
def testMap(hashMapConstructor: => mutable.Map[String, String]) {
def newMap() = hashMapConstructor
val testMap1 = newMap()
val testMap2 = newMap()
val name = testMap1.getClass.getSimpleName
test(name + " - basic test") {
// put, get, and apply
testMap1 += (("k1", "v1"))
assert(testMap1.get("k1").isDefined)
assert(testMap1.get("k1").get === "v1")
testMap1("k2") = "v2"
assert(testMap1.get("k2").isDefined)
assert(testMap1.get("k2").get === "v2")
assert(testMap1("k2") === "v2")
testMap1.update("k3", "v3")
assert(testMap1.get("k3").isDefined)
assert(testMap1.get("k3").get === "v3")
// remove
testMap1.remove("k1")
assert(testMap1.get("k1").isEmpty)
testMap1.remove("k2")
intercept[NoSuchElementException] {
testMap1("k2") // Map.apply(<non-existent-key>) causes exception
}
testMap1 -= "k3"
assert(testMap1.get("k3").isEmpty)
// multi put
val keys = (1 to 100).map(_.toString)
val pairs = keys.map(x => (x, x * 2))
assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet)
testMap2 ++= pairs
// iterator
assert(testMap2.iterator.toSet === pairs.toSet)
// filter
val filtered = testMap2.filter { case (_, v) => v.toInt % 2 == 0 }
val evenPairs = pairs.filter { case (_, v) => v.toInt % 2 == 0 }
assert(filtered.iterator.toSet === evenPairs.toSet)
// foreach
val buffer = new ArrayBuffer[(String, String)]
testMap2.foreach(x => buffer += x)
assert(testMap2.toSet === buffer.toSet)
// multi remove
testMap2("k1") = "v1"
testMap2 --= keys
assert(testMap2.size === 1)
assert(testMap2.iterator.toSeq.head === ("k1", "v1"))
// +
val testMap3 = testMap2 + (("k0", "v0"))
assert(testMap3.size === 2)
assert(testMap3.get("k1").isDefined)
assert(testMap3.get("k1").get === "v1")
assert(testMap3.get("k0").isDefined)
assert(testMap3.get("k0").get === "v0")
// -
val testMap4 = testMap3 - "k0"
assert(testMap4.size === 1)
assert(testMap4.get("k1").isDefined)
assert(testMap4.get("k1").get === "v1")
}
}
/** Test thread safety of a Scala mutable map. */
def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]) {
def newMap() = hashMapConstructor
val name = newMap().getClass.getSimpleName
val testMap = newMap()
@volatile var error = false
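// Set by any worker thread that throws; the test asserts it is still false after all threads join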
def getRandomKey(m: mutable.Map[String, String]): Option[String] = {
val keys = m.keysIterator.toSeq
if (keys.nonEmpty) {
Some(keys(Random.nextInt(keys.size)))
} else {
None
}
}
val threads = (1 to 25).map(i => new Thread() {
override def run() {
try {
for (j <- 1 to 1000) {
Random.nextInt(3) match {
case 0 =>
testMap(Random.nextString(10)) = Random.nextDouble().toString // put
case 1 =>
getRandomKey(testMap).map(testMap.get) // get
case 2 =>
getRandomKey(testMap).map(testMap.remove) // remove
}
}
} catch {
case t: Throwable =>
error = true
throw t
}
}
})
test(name + " - threading safety test") {
threads.foreach(_.start())
threads.foreach(_.join())
assert(!error)
}
}
}


@ -341,9 +341,11 @@ abstract class DStream[T: ClassTag] (
*/
private[streaming] def clearMetadata(time: Time) {
val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
logDebug("Clearing references to old RDDs: [" +
oldRDDs.map(x => s"${x._1} -> ${x._2.id}").mkString(", ") + "]")
generatedRDDs --= oldRDDs.keys
if (ssc.conf.getBoolean("spark.streaming.unpersist", false)) {
logDebug("Unpersisting old RDDs: " + oldRDDs.keys.mkString(", "))
logDebug("Unpersisting old RDDs: " + oldRDDs.values.map(_.id).mkString(", "))
oldRDDs.values.foreach(_.unpersist(false))
}
logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " +