Defensively allocate memory from global pool

This is an alternative to the existing approach, which evenly distributes the
collective shuffle memory among all running tasks. In the new approach, each
thread requests a chunk of memory whenever its map is about to multiplicatively
grow. If there is sufficient memory in the global pool, the thread allocates it
and grows its map. Otherwise, it spills.

A danger with the previous approach is that a new task may quickly fill up its
map before old tasks finish spilling, potentially causing an OOM. The new
approach prevents this scenario because it favors existing tasks over new
tasks: any thread that would step over the memory already claimed by other
threads defensively backs off and starts spilling.
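
A minimal sketch of the per-thread check (the names shuffleMemoryMap and
maxMemoryThreshold follow the diff below; the tryToGrow helper and its exact
placement are illustrative only, and the 2x factor assumes the map doubles
when it grows):

  // Decide whether this thread may grow its in-memory map or must spill.
  // shuffleMemoryMap maps thread ID -> bytes of shuffle memory currently reserved.
  def tryToGrow(mapSize: Long): Boolean = {
    val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
    shuffleMemoryMap.synchronized {
      val threadId = Thread.currentThread().getId
      val previouslyOccupied = shuffleMemoryMap.getOrElse(threadId, 0L)
      // Memory in the global pool not already claimed by other threads
      val availableMemory = maxMemoryThreshold -
        (shuffleMemoryMap.values.sum - previouslyOccupied)
      if (availableMemory >= mapSize * 2) {
        shuffleMemoryMap(threadId) = mapSize * 2  // reserve room for the map to double
        true
      } else {
        false  // back off; the caller spills to disk and resets its entry to 0
      }
    }
  }

Because the check and the reservation happen inside one synchronized block,
two threads cannot both reserve the last chunk of the pool.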

Testing with spark-perf reveals: (1) When no spills have occurred, the
performance of external sorting using this memory management approach is
essentially the same as without external sorting. (2) When one or more spills
have occurred, the performance of external sorting is a small multiple (3x) worse.
Andrew Or 2014-01-09 21:43:58 -08:00
parent d76e1f90a8
commit aa5002bb96
5 changed files with 80 additions and 47 deletions


@@ -17,8 +17,6 @@
package org.apache.spark
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable
import scala.concurrent.Await
@@ -56,13 +54,14 @@ class SparkEnv private[spark] (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val conf: SparkConf) {
val conf: SparkConf) extends Logging {
// A mapping of thread ID to amount of memory used for shuffle in bytes
// All accesses should be manually synchronized
val shuffleMemoryMap = mutable.HashMap[Long, Long]()
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
// Number of tasks currently running across all threads
private val _numRunningTasks = new AtomicInteger(0)
// A general, soft-reference map for metadata needed during HadoopRDD split computation
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
@@ -90,13 +89,6 @@ class SparkEnv private[spark] (
pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
}
}
/**
* Return the number of tasks currently running across all threads
*/
def numRunningTasks: Int = _numRunningTasks.intValue()
def incrementNumRunningTasks(): Int = _numRunningTasks.incrementAndGet()
def decrementNumRunningTasks(): Int = _numRunningTasks.decrementAndGet()
}
object SparkEnv extends Logging {


@@ -186,7 +186,6 @@ private[spark] class Executor(
var taskStart: Long = 0
def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
val startGCTime = gcTime
env.incrementNumRunningTasks()
try {
SparkEnv.set(env)
@@ -280,7 +279,11 @@ private[spark] class Executor(
//System.exit(1)
}
} finally {
env.decrementNumRunningTasks()
// TODO: Unregister shuffle memory only for ShuffleMapTask
val shuffleMemoryMap = env.shuffleMemoryMap
shuffleMemoryMap.synchronized {
shuffleMemoryMap.remove(Thread.currentThread().getId)
}
runningTasks.remove(taskId)
}
}


@@ -181,4 +181,8 @@ class DiskBlockObjectWriter(
// Only valid if called after close()
override def timeWriting() = _timeWriting
def bytesWritten: Long = {
lastValidPosition - initialPosition
}
}


@@ -30,14 +30,15 @@ import java.util.{Arrays, Comparator}
* TODO: Cache the hash values of each key? java.util.HashMap does that.
*/
private[spark]
class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable {
class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K,
V)] with Serializable {
require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
require(initialCapacity >= 1, "Invalid initial capacity")
private var capacity = nextPowerOf2(initialCapacity)
private var mask = capacity - 1
private var curSize = 0
private var growThreshold = LOAD_FACTOR * capacity
private var growThreshold = (LOAD_FACTOR * capacity).toInt
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
@@ -239,7 +240,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
data = newData
capacity = newCapacity
mask = newMask
growThreshold = LOAD_FACTOR * newCapacity
growThreshold = (LOAD_FACTOR * newCapacity).toInt
}
private def nextPowerOf2(n: Int): Int = {
@@ -288,4 +289,9 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
}
}
}
/**
* Return whether the next insert will cause the map to grow
*/
def atGrowThreshold: Boolean = curSize == growThreshold
}


@@ -22,14 +22,16 @@ import java.util.Comparator
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import scala.collection.mutable.{ArrayBuffer, PriorityQueue}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter}
/**
* An append-only map that spills sorted content to disk when the memory threshold is exceeded.
* An append-only map that spills sorted content to disk when there is insufficient space for it
* to grow.
*
* This map takes two passes over the data:
*
@@ -42,7 +44,7 @@ import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter}
* writes. This may lead to a performance regression compared to the normal case of using the
* non-spilling AppendOnlyMap.
*
* A few parameters control the memory threshold:
* Two parameters control the memory threshold:
*
* `spark.shuffle.memoryFraction` specifies the collective amount of memory used for storing
* these maps as a fraction of the executor's total memory. Since each concurrently running
@@ -51,9 +53,6 @@ import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter}
*
* `spark.shuffle.safetyFraction` specifies an additional margin of safety as a fraction of
* this threshold, in case map size estimation is not sufficiently accurate.
*
* `spark.shuffle.updateThresholdInterval` controls how frequently each thread checks on
* shared executor state to update its local memory threshold.
*/
private[spark] class ExternalAppendOnlyMap[K, V, C](
@@ -77,12 +76,9 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
(Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
}
// Maximum size for this map before a spill is triggered
private var spillThreshold = maxMemoryThreshold
// How often to update spillThreshold
private val updateThresholdInterval =
sparkConf.getInt("spark.shuffle.updateThresholdInterval", 100)
// How many inserts into this map before tracking its shuffle memory usage
private val initialInsertThreshold =
sparkConf.getLong("spark.shuffle.initialInsertThreshold", 1000)
private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
private val syncWrites = sparkConf.get("spark.shuffle.sync", "false").toBoolean
@@ -91,30 +87,54 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
private var insertCount = 0
private var spillCount = 0
/**
* Insert the given key and value into the map.
*
* If the underlying map is about to grow, check if the global pool of shuffle memory has
* enough room for this to happen. If so, allocate the memory required to grow the map;
* otherwise, spill the in-memory map to disk.
*
* The shuffle memory usage of the first initialInsertThreshold entries is not tracked.
*/
def insert(key: K, value: V) {
insertCount += 1
val update: (Boolean, C) => C = (hadVal, oldVal) => {
if (hadVal) mergeValue(oldVal, value) else createCombiner(value)
}
if (insertCount > initialInsertThreshold && currentMap.atGrowThreshold) {
val mapSize = currentMap.estimateSize()
var shouldSpill = false
val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
// Atomically check whether there is sufficient memory in the global pool for
// this map to grow and, if possible, allocate the required amount
shuffleMemoryMap.synchronized {
val threadId = Thread.currentThread().getId
val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
val availableMemory = maxMemoryThreshold -
(shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
// Assume map grow factor is 2x
shouldSpill = availableMemory < mapSize * 2
if (!shouldSpill) {
shuffleMemoryMap(threadId) = mapSize * 2
}
}
// Do not synchronize spills
if (shouldSpill) {
spill(mapSize)
}
}
currentMap.changeValue(key, update)
if (insertCount % updateThresholdInterval == 1) {
updateSpillThreshold()
}
if (currentMap.estimateSize() > spillThreshold) {
spill()
}
}
// TODO: differentiate ShuffleMapTask's from ResultTask's
private def updateSpillThreshold() {
val numRunningTasks = math.max(SparkEnv.get.numRunningTasks, 1)
spillThreshold = maxMemoryThreshold / numRunningTasks
}
private def spill() {
/**
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk
*/
private def spill(mapSize: Long) {
spillCount += 1
logWarning("In-memory map exceeded %s MB! Spilling to disk (%d time%s so far)"
.format(spillThreshold / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
logWarning("* Spilling in-memory map of %d MB to disk (%d time%s so far)"
.format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
val (blockId, file) = diskBlockManager.createTempBlock()
val writer =
new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize, identity, syncWrites)
@@ -131,6 +151,13 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
}
currentMap = new SizeTrackingAppendOnlyMap[K, C]
spilledMaps.append(new DiskMapIterator(file))
// Reset the amount of shuffle memory used by this map in the global pool
val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
shuffleMemoryMap.synchronized {
shuffleMemoryMap(Thread.currentThread().getId) = 0
}
insertCount = 0
}
override def iterator: Iterator[(K, C)] = {
@@ -145,11 +172,12 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
private class ExternalIterator extends Iterator[(K, C)] {
// A fixed-size queue that maintains a buffer for each stream we are currently merging
val mergeHeap = new PriorityQueue[StreamBuffer]
val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
val inputStreams = Seq(currentMap.destructiveSortedIterator(comparator)) ++ spilledMaps
val sortedMap = currentMap.destructiveSortedIterator(comparator)
val inputStreams = Seq(sortedMap) ++ spilledMaps
inputStreams.foreach{ it =>
val kcPairs = getMorePairs(it)