Merge remote-tracking branch 'upstream/dev' into dev

This commit is contained in:
Mosharaf Chowdhury 2012-08-30 15:21:08 -07:00
commit 8f2bd399da
5 changed files with 115 additions and 24 deletions

View file

@ -1,7 +1,17 @@
package spark
/** A set of functions used to aggregate data.
*
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue function.
* @param mapSideCombine whether to apply combiners on map partitions, also
* known as map-side aggregations. When set to false,
* mergeCombiners function is not used.
*/
class Aggregator[K, V, C] (
val createCombiner: V => C,
val mergeValue: (C, V) => C,
val mergeCombiners: (C, C) => C)
val mergeCombiners: (C, C) => C,
val mapSideCombine: Boolean = true)
extends Serializable

View file

@ -27,16 +27,35 @@ class ShuffledRDD[K, V, C](
override def compute(split: Split): Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
def mergePair(k: K, c: C) {
val oldC = combiners.get(k)
if (oldC == null) {
combiners.put(k, c)
} else {
combiners.put(k, aggregator.mergeCombiners(oldC, c))
}
}
val fetcher = SparkEnv.get.shuffleFetcher
fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair)
if (aggregator.mapSideCombine) {
// Apply combiners on map partitions. In this case, post-shuffle we get a
// list of outputs from the combiners and merge them using mergeCombiners.
def mergePairWithMapSideCombiners(k: K, c: C) {
val oldC = combiners.get(k)
if (oldC == null) {
combiners.put(k, c)
} else {
combiners.put(k, aggregator.mergeCombiners(oldC, c))
}
}
fetcher.fetch[K, C](dep.shuffleId, split.index, mergePairWithMapSideCombiners)
} else {
// Do not apply combiners on map partitions (i.e. map side aggregation is
// turned off). Post-shuffle we get a list of values and we use mergeValue
// to merge them.
def mergePairWithoutMapSideCombiners(k: K, v: V) {
val oldC = combiners.get(k)
if (oldC == null) {
combiners.put(k, aggregator.createCombiner(v))
} else {
combiners.put(k, aggregator.mergeValue(oldC, v))
}
}
fetcher.fetch[K, V](dep.shuffleId, split.index, mergePairWithoutMapSideCombiners)
}
return new Iterator[(K, C)] {
var iter = combiners.entrySet().iterator()

View file

@ -106,27 +106,44 @@ class ShuffleMapTask(
val numOutputSplits = dep.partitioner.numPartitions
val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
val partitioner = dep.partitioner
val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
for (elem <- rdd.iterator(split)) {
val (k, v) = elem.asInstanceOf[(Any, Any)]
var bucketId = partitioner.getPartition(k)
val bucket = buckets(bucketId)
var existing = bucket.get(k)
if (existing == null) {
bucket.put(k, aggregator.createCombiner(v))
val bucketIterators =
if (aggregator.mapSideCombine) {
// Apply combiners (map-side aggregation) to the map output.
val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
for (elem <- rdd.iterator(split)) {
val (k, v) = elem.asInstanceOf[(Any, Any)]
val bucketId = partitioner.getPartition(k)
val bucket = buckets(bucketId)
val existing = bucket.get(k)
if (existing == null) {
bucket.put(k, aggregator.createCombiner(v))
} else {
bucket.put(k, aggregator.mergeValue(existing, v))
}
}
buckets.map(_.iterator)
} else {
bucket.put(k, aggregator.mergeValue(existing, v))
// No combiners (no map-side aggregation). Simply partition the map output.
val buckets = Array.tabulate(numOutputSplits)(_ => new ArrayBuffer[(Any, Any)])
for (elem <- rdd.iterator(split)) {
val pair = elem.asInstanceOf[(Any, Any)]
val bucketId = partitioner.getPartition(pair._1)
buckets(bucketId) += pair
}
buckets.map(_.iterator)
}
}
val ser = SparkEnv.get.serializer.newInstance()
val blockManager = SparkEnv.get.blockManager
for (i <- 0 until numOutputSplits) {
val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
// Get a scala iterator from java map
val iter: Iterator[(Any, Any)] = buckets(i).iterator
val iter: Iterator[(Any, Any)] = bucketIterators(i)
// TODO: This should probably be DISK_ONLY
blockManager.put(blockId, iter, StorageLevel.MEMORY_ONLY, false)
}
return SparkEnv.get.blockManager.blockManagerId
}

View file

@ -2,6 +2,7 @@ package spark
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
@ -13,7 +14,7 @@ import scala.collection.mutable.ArrayBuffer
import SparkContext._
class ShuffleSuite extends FunSuite with BeforeAndAfter {
class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
var sc: SparkContext = _
@ -196,4 +197,46 @@ class ShuffleSuite extends FunSuite with BeforeAndAfter {
// Test that a shuffle on the file works, because this used to be a bug
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
}
test("map-side combine") {
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1), (1, 1)), 2)
// Test with map-side combine on.
val sums = pairs.reduceByKey(_+_).collect()
assert(sums.toSet === Set((1, 8), (2, 1)))
// Turn off map-side combine and test the results.
val aggregator = new Aggregator[Int, Int, Int](
(v: Int) => v,
_+_,
_+_,
false)
val shuffledRdd = new ShuffledRDD(
pairs, aggregator, new HashPartitioner(2))
assert(shuffledRdd.collect().toSet === Set((1, 8), (2, 1)))
// Turn map-side combine off and pass a wrong mergeCombine function. Should
// not see an exception because mergeCombine should not have been called.
val aggregatorWithException = new Aggregator[Int, Int, Int](
(v: Int) => v, _+_, ShuffleSuite.mergeCombineException, false)
val shuffledRdd1 = new ShuffledRDD(
pairs, aggregatorWithException, new HashPartitioner(2))
assert(shuffledRdd1.collect().toSet === Set((1, 8), (2, 1)))
// Now run the same mergeCombine function with map-side combine on. We
// expect to see an exception thrown.
val aggregatorWithException1 = new Aggregator[Int, Int, Int](
(v: Int) => v, _+_, ShuffleSuite.mergeCombineException)
val shuffledRdd2 = new ShuffledRDD(
pairs, aggregatorWithException1, new HashPartitioner(2))
evaluating { shuffledRdd2.collect() } should produce [SparkException]
}
}
object ShuffleSuite {
def mergeCombineException(x: Int, y: Int): Int = {
throw new SparkException("Exception for map-side combine.")
x + y
}
}

View file

@ -59,7 +59,9 @@ def parse_args():
"WARNING: must be 64-bit; small instances won't work")
parser.add_option("-m", "--master-instance-type", default="",
help="Master instance type (leave empty for same as instance-type)")
parser.add_option("-z", "--zone", default="us-east-1b",
parser.add_option("-r", "--region", default="us-east-1",
help="EC2 region zone to launch instances in")
parser.add_option("-z", "--zone", default="",
help="Availability zone to launch instances in")
parser.add_option("-a", "--ami", default="latest",
help="Amazon Machine Image ID to use, or 'latest' to use latest " +
@ -470,7 +472,7 @@ def ssh(host, opts, command):
def main():
(opts, action, cluster_name) = parse_args()
conn = boto.connect_ec2()
conn = boto.ec2.connect_to_region(opts.region)
# Select an AZ at random if it was not specified.
if opts.zone == "":