[SPARK-2044] Pluggable interface for shuffles

This is a first cut at moving shuffle logic behind a pluggable interface, as described at https://issues.apache.org/jira/browse/SPARK-2044, to let us experiment more easily with new shuffle implementations. It moves the existing hash-based shuffle code into a HashShuffleManager class behind a general ShuffleManager interface.

Two things are still missing to make this complete:
* MapOutputTracker needs to be hidden behind the ShuffleManager interface; this will also require adding methods to ShuffleManager that will let the DAGScheduler interact with it as it does with the MapOutputTracker today
* The code to do map-side and reduce-side combine in ShuffledRDD, PairRDDFunctions, etc. needs to be moved into the ShuffleManager's readers and writers

However, some of this work can also be done later, after we merge the current interface.
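
As a quick illustration of what the pluggable interface enables, here is a minimal sketch (not part of this patch) that assumes only the new `spark.shuffle.manager` setting and the hash-based default added below; the application code itself is made up:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._  // pair-RDD operations such as reduceByKey (Spark 1.x)

object PluggableShuffleDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("pluggable-shuffle-demo")
      .setMaster("local[2]")
      // The hash-based manager is the default after this patch; an experimental
      // ShuffleManager implementation would be named here instead.
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")

    val sc = new SparkContext(conf)
    // Any shuffle (e.g. reduceByKey) now goes through the configured ShuffleManager.
    val counts = sc.parallelize(Seq("a" -> 1, "b" -> 2, "a" -> 3)).reduceByKey(_ + _)
    println(counts.collect().mkString(", "))
    sc.stop()
  }
}
```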

Author: Matei Zaharia <matei@databricks.com>

Closes #1009 from mateiz/pluggable-shuffle and squashes the following commits:

7a09862 [Matei Zaharia] review comments
be33d3f [Matei Zaharia] review comments
1513d4e [Matei Zaharia] Add ASF header
ac56831 [Matei Zaharia] Bug fix and better error message
4f681ba [Matei Zaharia] Move write part of ShuffleMapTask to ShuffleManager
f6f011d [Matei Zaharia] Move hash shuffle reader behind ShuffleManager interface
55c7717 [Matei Zaharia] Changed RDD code to use ShuffleReader
75cc044 [Matei Zaharia] Partial work to move hash shuffle in
Matei Zaharia authored on 2014-06-11 20:45:29 -07:00; committed by Reynold Xin
parent d9203350b0
commit 508fd371d6
22 changed files with 459 additions and 130 deletions


@ -96,7 +96,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
/** Register a ShuffleDependency for cleanup when it is garbage collected. */
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) {
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
}


@ -20,6 +20,7 @@ package org.apache.spark
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle
/**
* :: DeveloperApi ::
@ -50,19 +51,24 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null,
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
*/
@DeveloperApi
class ShuffleDependency[K, V](
class ShuffleDependency[K, V, C](
@transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
val serializer: Serializer = null)
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
val shuffleId: Int = rdd.context.newShuffleId()
val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
shuffleId, rdd.partitions.size, this)
rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
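
For illustration only (not part of the patch), the widened ShuffleDependency[K, V, C] constructor could be exercised like this in the spark-shell; the values are made up and a live SparkContext `sc` is assumed:

```scala
import org.apache.spark.{HashPartitioner, ShuffleDependency}
import org.apache.spark.rdd.RDD

// Illustrative only: K = String, V = Int, and the combiner type C is also Int.
val pairs: RDD[(String, Int)] = sc.parallelize(Seq("a" -> 1, "b" -> 2, "a" -> 3))
val dep = new ShuffleDependency[String, Int, Int](
  pairs,
  new HashPartitioner(4),
  serializer = None,                    // None => fall back to spark.serializer
  keyOrdering = Some(Ordering[String]), // new: optional ordering on keys
  aggregator = None)                    // new: optional map/reduce-side combine
println(dep.shuffleId)                  // the dependency registered itself with the ShuffleManager
```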


@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.ConnectionManager
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
import org.apache.spark.util.{AkkaUtils, Utils}
@ -56,7 +57,7 @@ class SparkEnv (
val closureSerializer: Serializer,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
@ -80,7 +81,7 @@ class SparkEnv (
pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
shuffleFetcher.stop()
shuffleManager.stop()
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
@ -163,13 +164,20 @@ object SparkEnv extends Logging {
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = conf.get(propertyName, defaultClassName)
val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader)
// First try with the constructor that takes SparkConf. If we can't find one,
// use a no-arg constructor instead.
// Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
// SparkConf, then one taking no arguments
try {
cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
.newInstance(conf, new java.lang.Boolean(isDriver))
.asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
cls.getConstructor().newInstance().asInstanceOf[T]
try {
cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
cls.getConstructor().newInstance().asInstanceOf[T]
}
}
}
@ -219,9 +227,6 @@ object SparkEnv extends Logging {
val cacheManager = new CacheManager(blockManager)
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
val httpFileServer = new HttpFileServer(securityManager)
httpFileServer.initialize()
conf.set("spark.fileserver.uri", httpFileServer.serverUri)
@ -242,6 +247,9 @@ object SparkEnv extends Logging {
"."
}
val shuffleManager = instantiateClass[ShuffleManager](
"spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
// Warn about deprecated spark.cache.class property
if (conf.contains("spark.cache.class")) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@ -255,7 +263,7 @@ object SparkEnv extends Logging {
closureSerializer,
cacheManager,
mapOutputTracker,
shuffleFetcher,
shuffleManager,
broadcastManager,
blockManager,
connectionManager,


@ -27,6 +27,7 @@ import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle
private[spark] sealed trait CoGroupSplitDep extends Serializable
@ -44,7 +45,7 @@ private[spark] case class NarrowCoGroupSplitDep(
}
}
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep
private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
@ -74,10 +75,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Seq[CoGroup]
private var serializer: Serializer = null
private var serializer: Option[Serializer] = None
/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
this.serializer = serializer
this.serializer = Option(serializer)
this
}
@ -88,7 +90,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
new ShuffleDependency[Any, Any](rdd, part, serializer)
new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer)
}
}
}
@ -100,8 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) =>
// Assume each RDD contributed a single dependency, and get it
dependencies(j) match {
case s: ShuffleDependency[_, _] =>
new ShuffleCoGroupSplitDep(s.shuffleId)
case s: ShuffleDependency[_, _, _] =>
new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
@ -126,11 +128,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
rddIterators += ((it, depNum))
case ShuffleCoGroupSplitDep(shuffleId) =>
case ShuffleCoGroupSplitDep(handle) =>
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
val ser = Serializer.getSerializer(serializer)
val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
val it = SparkEnv.get.shuffleManager
.getReader(handle, split.index, split.index + 1, context)
.read()
rddIterators += ((it, depNum))
}


@ -42,10 +42,11 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
part: Partitioner)
extends RDD[P](prev.context, Nil) {
private var serializer: Serializer = null
private var serializer: Option[Serializer] = None
/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
this.serializer = serializer
this.serializer = Option(serializer)
this
}
@ -60,9 +61,10 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
}
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
val ser = Serializer.getSerializer(serializer)
SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[P]]
}
override def clearDependencies() {


@ -54,10 +54,11 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {
private var serializer: Serializer = null
private var serializer: Option[Serializer] = None
/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
this.serializer = serializer
this.serializer = Option(serializer)
this
}
@ -79,8 +80,8 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
// Each CoGroupPartition will depend on rdd1 and rdd2
array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
dependencies(j) match {
case s: ShuffleDependency[_, _] =>
new ShuffleCoGroupSplitDep(s.shuffleId)
case s: ShuffleDependency[_, _, _] =>
new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
@ -93,7 +94,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
val ser = Serializer.getSerializer(serializer)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@ -109,9 +109,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
case ShuffleCoGroupSplitDep(shuffleId) =>
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
context, ser)
case ShuffleCoGroupSplitDep(handle) =>
val iter = SparkEnv.get.shuffleManager
.getReader(handle, partition.index, partition.index + 1, context)
.read()
iter.foreach(op)
}
// the first dep is rdd1; add all values to the map


@ -190,7 +190,7 @@ class DAGScheduler(
* The jobId value passed in will be used if the stage doesn't already exist with
* a lower jobId (jobId always increases across jobs.)
*/
private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = {
private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@ -210,7 +210,7 @@ class DAGScheduler(
private def newStage(
rdd: RDD[_],
numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
shuffleDep: Option[ShuffleDependency[_, _, _]],
jobId: Int,
callSite: Option[String] = None)
: Stage =
@ -233,7 +233,7 @@ class DAGScheduler(
private def newOrUsedStage(
rdd: RDD[_],
numTasks: Int,
shuffleDep: ShuffleDependency[_,_],
shuffleDep: ShuffleDependency[_, _, _],
jobId: Int,
callSite: Option[String] = None)
: Stage =
@ -269,7 +269,7 @@ class DAGScheduler(
// we can't do it in its constructor because # of partitions is unknown
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
case shufDep: ShuffleDependency[_, _, _] =>
parents += getShuffleMapStage(shufDep, jobId)
case _ =>
visit(dep.rdd)
@ -290,7 +290,7 @@ class DAGScheduler(
if (getCacheLocs(rdd).contains(Nil)) {
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
missing += mapStage
@ -1088,7 +1088,7 @@ class DAGScheduler(
visitedRdds += rdd
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
visitedStages += mapStage


@ -29,6 +29,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
import org.apache.spark.shuffle.ShuffleWriter
private[spark] object ShuffleMapTask {
@ -37,7 +38,7 @@ private[spark] object ShuffleMapTask {
// expensive on the master node if it needs to launch thousands of tasks.
private val serializedInfoCache = new HashMap[Int, Array[Byte]]
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = {
synchronized {
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
@ -56,12 +57,12 @@ private[spark] object ShuffleMapTask {
}
}
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = {
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = {
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]]
(rdd, dep)
}
@ -99,7 +100,7 @@ private[spark] object ShuffleMapTask {
private[spark] class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
var dep: ShuffleDependency[_, _, _],
_partitionId: Int,
@transient private var locs: Seq[TaskLocation])
extends Task[MapStatus](stageId, _partitionId)
@ -141,66 +142,22 @@ private[spark] class ShuffleMapTask(
}
override def runTask(context: TaskContext): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
metrics = Some(context.taskMetrics)
val blockManager = SparkEnv.get.blockManager
val shuffleBlockManager = blockManager.shuffleBlockManager
var shuffle: ShuffleWriterGroup = null
var success = false
var writer: ShuffleWriter[Any, Any] = null
try {
// Obtain all the block writers for shuffle blocks.
val ser = Serializer.getSerializer(dep.serializer)
shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
// Write the map output to its associated buckets.
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
for (elem <- rdd.iterator(split, context)) {
val pair = elem.asInstanceOf[Product2[Any, Any]]
val bucketId = dep.partitioner.getPartition(pair._1)
shuffle.writers(bucketId).write(pair)
writer.write(elem.asInstanceOf[Product2[Any, Any]])
}
// Commit the writes. Get the size of each bucket block (total block size).
var totalBytes = 0L
var totalTime = 0L
val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commit()
writer.close()
val size = writer.fileSegment().length
totalBytes += size
totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
}
// Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
shuffleMetrics.shuffleWriteTime = totalTime
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
success = true
new MapStatus(blockManager.blockManagerId, compressedSizes)
} catch { case e: Exception =>
// If there is an exception from running the task, revert the partial writes
// and throw the exception upstream to Spark.
if (shuffle != null && shuffle.writers != null) {
for (writer <- shuffle.writers) {
writer.revertPartialWrites()
writer.close()
return writer.stop(success = true).get
} catch {
case e: Exception =>
if (writer != null) {
writer.stop(success = false)
}
}
throw e
throw e
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && shuffle.writers != null) {
try {
shuffle.releaseWriters(success)
} catch {
case e: Exception => logError("Failed to release shuffle writers", e)
}
}
// Execute the callbacks on task completion.
context.executeOnCompleteCallbacks()
}
}


@ -40,7 +40,7 @@ private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
val numTasks: Int,
val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,
callSite: Option[String])


@ -71,7 +71,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
val loader = Thread.currentThread.getContextClassLoader
taskSetManager.abort("ClassNotFound with classloader: " + loader)
case ex: Exception =>
taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
logError("Exception while getting task result", ex)
taskSetManager.abort("Exception while getting task result: %s".format(ex))
}
}
})


@ -52,6 +52,10 @@ object Serializer {
def getSerializer(serializer: Serializer): Serializer = {
if (serializer == null) SparkEnv.get.serializer else serializer
}
def getSerializer(serializer: Option[Serializer]): Serializer = {
serializer.getOrElse(SparkEnv.get.serializer)
}
}


@ -15,22 +15,16 @@
* limitations under the License.
*/
package org.apache.spark
package org.apache.spark.shuffle
import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner}
import org.apache.spark.serializer.Serializer
private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
def fetch[T](
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer = SparkEnv.get.serializer): Iterator[T]
/** Stop the fetcher */
def stop() {}
}
/**
* A basic ShuffleHandle implementation that just captures registerShuffle's parameters.
*/
private[spark] class BaseShuffleHandle[K, V, C](
shuffleId: Int,
val numMaps: Int,
val dependency: ShuffleDependency[K, V, C])
extends ShuffleHandle(shuffleId)


@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle
/**
* An opaque handle to a shuffle, used by a ShuffleManager to pass information about it to tasks.
*
* @param shuffleId ID of the shuffle
*/
private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {}


@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle
import org.apache.spark.{TaskContext, ShuffleDependency}
/**
* Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the
* driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles
* with it, and executors (or tasks running locally in the driver) can ask to read and write data.
*
* NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and
* boolean isDriver as parameters.
*/
private[spark] trait ShuffleManager {
/**
* Register a shuffle with the manager and obtain a handle for it to pass to tasks.
*/
def registerShuffle[K, V, C](
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle
/** Get a writer for a given partition. Called on executors by map tasks. */
def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]
/**
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
* Called on executors by reduce tasks.
*/
def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C]
/** Remove a shuffle's metadata from the ShuffleManager. */
def unregisterShuffle(shuffleId: Int)
/** Shut down this ShuffleManager. */
def stop(): Unit
}
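
To show what an experimental implementation might look like against this trait, here is a hedged sketch (not part of the patch) that wraps the hash-based manager added below and logs shuffle registrations. The shuffle traits are currently private[spark], so such a class has to live under the org.apache.spark package, and SparkEnv will look for a (SparkConf, Boolean) constructor first, then (SparkConf), then a no-arg one:

```scala
package org.apache.spark.shuffle.demo  // hypothetical subpackage, kept under org.apache.spark

import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.hash.HashShuffleManager

// Sketch of an experimental manager: delegate everything to the hash-based
// implementation and add a log line on registration. It would be enabled with
// spark.shuffle.manager=org.apache.spark.shuffle.demo.LoggingShuffleManager.
class LoggingShuffleManager(conf: SparkConf) extends ShuffleManager {
  private val delegate = new HashShuffleManager(conf)

  override def registerShuffle[K, V, C](
      shuffleId: Int,
      numMaps: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
    println(s"Registering shuffle $shuffleId with $numMaps map tasks")
    delegate.registerShuffle(shuffleId, numMaps, dependency)
  }

  override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
    : ShuffleWriter[K, V] =
    delegate.getWriter[K, V](handle, mapId, context)

  override def getReader[K, C](
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext): ShuffleReader[K, C] =
    delegate.getReader[K, C](handle, startPartition, endPartition, context)

  override def unregisterShuffle(shuffleId: Int): Unit = delegate.unregisterShuffle(shuffleId)

  override def stop(): Unit = delegate.stop()
}
```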


@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle
/**
* Obtained inside a reduce task to read combined records from the mappers.
*/
private[spark] trait ShuffleReader[K, C] {
/** Read the combined key-values for this reduce task */
def read(): Iterator[Product2[K, C]]
/** Close this reader */
def stop(): Unit
}


@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle
import org.apache.spark.scheduler.MapStatus
/**
* Obtained inside a map task to write out records to the shuffle system.
*/
private[spark] trait ShuffleWriter[K, V] {
/** Write a record to this task's output */
def write(record: Product2[K, V]): Unit
/** Close this writer, passing along whether the map completed */
def stop(success: Boolean): Option[MapStatus]
}


@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.spark
package org.apache.spark.shuffle.hash
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@ -24,17 +24,16 @@ import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
import org.apache.spark._
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
override def fetch[T](
private[hash] object BlockStoreShuffleFetcher extends Logging {
def fetch[T](
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager


@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.hash
import org.apache.spark._
import org.apache.spark.shuffle._
/**
* A ShuffleManager using hashing, that creates one output file per reduce partition on each
* mapper (possibly reusing these across waves of tasks).
*/
class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
/* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
override def registerShuffle[K, V, C](
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
new BaseShuffleHandle(shuffleId, numMaps, dependency)
}
/**
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
* Called on executors by reduce tasks.
*/
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
new HashShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
: ShuffleWriter[K, V] = {
new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
}
/** Remove a shuffle's metadata from the ShuffleManager. */
override def unregisterShuffle(shuffleId: Int): Unit = {}
/** Shut down this ShuffleManager. */
override def stop(): Unit = {}
}


@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.hash
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.TaskContext
class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
context: TaskContext)
extends ShuffleReader[K, C]
{
require(endPartition == startPartition + 1,
"Hash shuffle currently only supports fetching one partition")
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
Serializer.getSerializer(handle.dependency.serializer))
}
/** Close this reader */
override def stop(): Unit = ???
}


@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.hash
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext}
import org.apache.spark.storage.{BlockObjectWriter}
import org.apache.spark.serializer.Serializer
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
class HashShuffleWriter[K, V](
handle: BaseShuffleHandle[K, V, _],
mapId: Int,
context: TaskContext)
extends ShuffleWriter[K, V] with Logging {
private val dep = handle.dependency
private val numOutputSplits = dep.partitioner.numPartitions
private val metrics = context.taskMetrics
private var stopping = false
private val blockManager = SparkEnv.get.blockManager
private val shuffleBlockManager = blockManager.shuffleBlockManager
private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)
/** Write a record to this task's output */
override def write(record: Product2[K, V]): Unit = {
val pair = record.asInstanceOf[Product2[Any, Any]]
val bucketId = dep.partitioner.getPartition(pair._1)
shuffle.writers(bucketId).write(pair)
}
/** Close this writer, passing along whether the map completed */
override def stop(success: Boolean): Option[MapStatus] = {
try {
if (stopping) {
return None
}
stopping = true
if (success) {
try {
return Some(commitWritesAndBuildStatus())
} catch {
case e: Exception =>
revertWrites()
throw e
}
} else {
revertWrites()
return None
}
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && shuffle.writers != null) {
try {
shuffle.releaseWriters(success)
} catch {
case e: Exception => logError("Failed to release shuffle writers", e)
}
}
}
}
private def commitWritesAndBuildStatus(): MapStatus = {
// Commit the writes. Get the size of each bucket block (total block size).
var totalBytes = 0L
var totalTime = 0L
val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commit()
writer.close()
val size = writer.fileSegment().length
totalBytes += size
totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
}
// Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
shuffleMetrics.shuffleWriteTime = totalTime
metrics.shuffleWriteMetrics = Some(shuffleMetrics)
new MapStatus(blockManager.blockManagerId, compressedSizes)
}
private def revertWrites(): Unit = {
if (shuffle != null && shuffle.writers != null) {
for (writer <- shuffle.writers) {
writer.revertPartialWrites()
writer.close()
}
}
}
}


@ -201,7 +201,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
def newPairRDD = newRDD.map(_ -> 1)
def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
def newBroadcast = sc.broadcast(1 to 100)
def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = {
def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
getAllDependencies(dep.rdd)
@ -211,8 +211,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
// Get all the shuffle dependencies
val shuffleDeps = getAllDependencies(rdd)
.filter(_.isInstanceOf[ShuffleDependency[_, _]])
.map(_.asInstanceOf[ShuffleDependency[_, _]])
.filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
.map(_.asInstanceOf[ShuffleDependency[_, _, _]])
(rdd, shuffleDeps)
}


@ -58,7 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 10)
@ -97,7 +97,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
.setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
@ -122,7 +122,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// NOTE: The default Java serializer should create zero-sized blocks
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>