Merge remote-tracking branch 'upstream/dev' into dev
commit 8f2bd399da
@@ -1,7 +1,17 @@
 package spark
 
+/** A set of functions used to aggregate data.
+ *
+ * @param createCombiner function to create the initial value of the aggregation.
+ * @param mergeValue function to merge a new value into the aggregation result.
+ * @param mergeCombiners function to merge outputs from multiple mergeValue function.
+ * @param mapSideCombine whether to apply combiners on map partitions, also
+ *                       known as map-side aggregations. When set to false,
+ *                       mergeCombiners function is not used.
+ */
 class Aggregator[K, V, C] (
     val createCombiner: V => C,
     val mergeValue: (C, V) => C,
-    val mergeCombiners: (C, C) => C)
+    val mergeCombiners: (C, C) => C,
+    val mapSideCombine: Boolean = true)
   extends Serializable
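For context on how these three functions compose: createCombiner seeds a combiner from a key's first value, mergeValue folds further values into that combiner on the map side, and mergeCombiners merges partial combiners after the shuffle. The new mapSideCombine flag only decides whether the first two run before the shuffle. Below is a self-contained sketch of the contract, with made-up data and a sum/count (average) combiner; it is an illustration only, not Spark code.

// Illustration of the Aggregator contract; all data and names are made up.
object AggregatorContractSketch {
  def main(args: Array[String]): Unit = {
    // Combiner = (runningSum, count), so the final answer is a per-key average.
    val createCombiner: Int => (Int, Int) = v => (v, 1)
    val mergeValue: ((Int, Int), Int) => (Int, Int) = (c, v) => (c._1 + v, c._2 + 1)
    val mergeCombiners: ((Int, Int), (Int, Int)) => (Int, Int) =
      (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2)

    // Two "map partitions" of (key, value) pairs.
    val part1 = Seq("a" -> 1, "b" -> 2, "a" -> 3)
    val part2 = Seq("a" -> 5, "b" -> 7)

    // Map side (what mapSideCombine = true enables): fold each partition locally.
    def combineLocally(part: Seq[(String, Int)]): Map[String, (Int, Int)] =
      part.foldLeft(Map.empty[String, (Int, Int)]) { case (acc, (k, v)) =>
        acc.updated(k, acc.get(k).map(mergeValue(_, v)).getOrElse(createCombiner(v)))
      }

    // Reduce side: merge the partial combiners coming from each partition.
    val merged = (combineLocally(part1).toSeq ++ combineLocally(part2).toSeq)
      .groupBy(_._1)
      .map { case (k, cs) => k -> cs.map(_._2).reduce(mergeCombiners) }

    merged.foreach { case (k, (sum, n)) => println(s"$k: avg = ${sum.toDouble / n}") }
  }
}

With mapSideCombine = false, the combineLocally step is skipped, raw pairs cross the shuffle, and mergeCombiners is never needed, which is exactly what the new doc comment states.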
@@ -27,16 +27,35 @@ class ShuffledRDD[K, V, C](
 
   override def compute(split: Split): Iterator[(K, C)] = {
     val combiners = new JHashMap[K, C]
-    def mergePair(k: K, c: C) {
-      val oldC = combiners.get(k)
-      if (oldC == null) {
-        combiners.put(k, c)
-      } else {
-        combiners.put(k, aggregator.mergeCombiners(oldC, c))
-      }
-    }
     val fetcher = SparkEnv.get.shuffleFetcher
-    fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair)
+
+    if (aggregator.mapSideCombine) {
+      // Apply combiners on map partitions. In this case, post-shuffle we get a
+      // list of outputs from the combiners and merge them using mergeCombiners.
+      def mergePairWithMapSideCombiners(k: K, c: C) {
+        val oldC = combiners.get(k)
+        if (oldC == null) {
+          combiners.put(k, c)
+        } else {
+          combiners.put(k, aggregator.mergeCombiners(oldC, c))
+        }
+      }
+      fetcher.fetch[K, C](dep.shuffleId, split.index, mergePairWithMapSideCombiners)
+    } else {
+      // Do not apply combiners on map partitions (i.e. map side aggregation is
+      // turned off). Post-shuffle we get a list of values and we use mergeValue
+      // to merge them.
+      def mergePairWithoutMapSideCombiners(k: K, v: V) {
+        val oldC = combiners.get(k)
+        if (oldC == null) {
+          combiners.put(k, aggregator.createCombiner(v))
+        } else {
+          combiners.put(k, aggregator.mergeValue(oldC, v))
+        }
+      }
+      fetcher.fetch[K, V](dep.shuffleId, split.index, mergePairWithoutMapSideCombiners)
+    }
+
     return new Iterator[(K, C)] {
       var iter = combiners.entrySet().iterator()
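This hunk is the reduce-side half of the flag. With map-side combine on, the fetcher delivers (K, C) combiner outputs and only mergeCombiners is applied; with it off, the fetcher delivers raw (K, V) pairs, so the combiners have to be built here from createCombiner and mergeValue. Here is a condensed model of the two merge paths, with a plain Seq standing in for what the fetcher callback would receive; names are illustrative.

import java.util.{HashMap => JHashMap}

// Condensed model of the two merge paths in compute(); `fetched` is a
// stand-in for the stream the shuffle fetcher feeds the merge callback.
object ReduceSideMergeSketch {
  // mapSideCombine = true: inputs are already combiners.
  def mergeCombinerStream[K, C](
      fetched: Seq[(K, C)], mergeCombiners: (C, C) => C): JHashMap[K, C] = {
    val combiners = new JHashMap[K, C]
    for ((k, c) <- fetched) {
      val oldC = combiners.get(k)
      combiners.put(k, if (oldC == null) c else mergeCombiners(oldC, c))
    }
    combiners
  }

  // mapSideCombine = false: inputs are raw values; build combiners here.
  def mergeValueStream[K, V, C](
      fetched: Seq[(K, V)],
      createCombiner: V => C,
      mergeValue: (C, V) => C): JHashMap[K, C] = {
    val combiners = new JHashMap[K, C]
    for ((k, v) <- fetched) {
      val oldC = combiners.get(k)
      combiners.put(k, if (oldC == null) createCombiner(v) else mergeValue(oldC, v))
    }
    combiners
  }
}

Like the real code, the sketch uses JHashMap's null return to detect first-seen keys, which assumes a combiner value is never legitimately null.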
@@ -106,27 +106,44 @@ class ShuffleMapTask(
     val numOutputSplits = dep.partitioner.numPartitions
     val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
     val partitioner = dep.partitioner
-    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
-    for (elem <- rdd.iterator(split)) {
-      val (k, v) = elem.asInstanceOf[(Any, Any)]
-      var bucketId = partitioner.getPartition(k)
-      val bucket = buckets(bucketId)
-      var existing = bucket.get(k)
-      if (existing == null) {
-        bucket.put(k, aggregator.createCombiner(v))
-      } else {
-        bucket.put(k, aggregator.mergeValue(existing, v))
-      }
-    }
+
+    val bucketIterators =
+      if (aggregator.mapSideCombine) {
+        // Apply combiners (map-side aggregation) to the map output.
+        val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+        for (elem <- rdd.iterator(split)) {
+          val (k, v) = elem.asInstanceOf[(Any, Any)]
+          val bucketId = partitioner.getPartition(k)
+          val bucket = buckets(bucketId)
+          val existing = bucket.get(k)
+          if (existing == null) {
+            bucket.put(k, aggregator.createCombiner(v))
+          } else {
+            bucket.put(k, aggregator.mergeValue(existing, v))
+          }
+        }
+        buckets.map(_.iterator)
+      } else {
+        // No combiners (no map-side aggregation). Simply partition the map output.
+        val buckets = Array.tabulate(numOutputSplits)(_ => new ArrayBuffer[(Any, Any)])
+        for (elem <- rdd.iterator(split)) {
+          val pair = elem.asInstanceOf[(Any, Any)]
+          val bucketId = partitioner.getPartition(pair._1)
+          buckets(bucketId) += pair
+        }
+        buckets.map(_.iterator)
+      }
+
     val ser = SparkEnv.get.serializer.newInstance()
     val blockManager = SparkEnv.get.blockManager
     for (i <- 0 until numOutputSplits) {
       val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
-      // Get a scala iterator from java map
-      val iter: Iterator[(Any, Any)] = buckets(i).iterator
+      val iter: Iterator[(Any, Any)] = bucketIterators(i)
       // TODO: This should probably be DISK_ONLY
       blockManager.put(blockId, iter, StorageLevel.MEMORY_ONLY, false)
     }
 
     return SparkEnv.get.blockManager.blockManagerId
   }
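On the map side the same flag picks the bucket representation: hash maps holding one combiner per distinct key when combining, or buffers holding one entry per record when not. A toy version of the two paths follows, where getPartition stands in for Partitioner.getPartition and everything else is illustrative.

import scala.collection.mutable.{ArrayBuffer, HashMap}

// Toy model of the two map-side bucketing paths; names are illustrative.
object MapSideBucketsSketch {
  // mapSideCombine = true: pre-aggregate per bucket, one entry per distinct key.
  def combinedBuckets[K, V, C](
      records: Iterator[(K, V)], numBuckets: Int, getPartition: K => Int,
      createCombiner: V => C, mergeValue: (C, V) => C): Array[Iterator[(K, C)]] = {
    val buckets = Array.fill(numBuckets)(new HashMap[K, C])
    for ((k, v) <- records) {
      val bucket = buckets(getPartition(k))
      bucket(k) = bucket.get(k) match {
        case Some(c) => mergeValue(c, v)
        case None    => createCombiner(v)
      }
    }
    buckets.map(_.iterator)
  }

  // mapSideCombine = false: just route raw pairs; every record is shuffled.
  def rawBuckets[K, V](
      records: Iterator[(K, V)], numBuckets: Int,
      getPartition: K => Int): Array[Iterator[(K, V)]] = {
    val buckets = Array.fill(numBuckets)(new ArrayBuffer[(K, V)])
    for ((k, v) <- records) buckets(getPartition(k)) += ((k, v))
    buckets.map(_.iterator)
  }
}

The trade-off the flag exposes is the usual one: pre-aggregation shrinks shuffle output whenever keys repeat within a partition, at the cost of one hash map per output bucket on the map side.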
@@ -2,6 +2,7 @@ package spark
 
 import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.ShouldMatchers
 import org.scalatest.prop.Checkers
 import org.scalacheck.Arbitrary._
 import org.scalacheck.Gen
@@ -13,7 +14,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import SparkContext._
 
-class ShuffleSuite extends FunSuite with BeforeAndAfter {
+class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
 
   var sc: SparkContext = _
 
@@ -196,4 +197,46 @@ class ShuffleSuite extends FunSuite with BeforeAndAfter {
     // Test that a shuffle on the file works, because this used to be a bug
     assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
   }
+
+  test("map-side combine") {
+    sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1), (1, 1)), 2)
+
+    // Test with map-side combine on.
+    val sums = pairs.reduceByKey(_+_).collect()
+    assert(sums.toSet === Set((1, 8), (2, 1)))
+
+    // Turn off map-side combine and test the results.
+    val aggregator = new Aggregator[Int, Int, Int](
+      (v: Int) => v,
+      _+_,
+      _+_,
+      false)
+    val shuffledRdd = new ShuffledRDD(
+      pairs, aggregator, new HashPartitioner(2))
+    assert(shuffledRdd.collect().toSet === Set((1, 8), (2, 1)))
+
+    // Turn map-side combine off and pass a wrong mergeCombine function. Should
+    // not see an exception because mergeCombine should not have been called.
+    val aggregatorWithException = new Aggregator[Int, Int, Int](
+      (v: Int) => v, _+_, ShuffleSuite.mergeCombineException, false)
+    val shuffledRdd1 = new ShuffledRDD(
+      pairs, aggregatorWithException, new HashPartitioner(2))
+    assert(shuffledRdd1.collect().toSet === Set((1, 8), (2, 1)))
+
+    // Now run the same mergeCombine function with map-side combine on. We
+    // expect to see an exception thrown.
+    val aggregatorWithException1 = new Aggregator[Int, Int, Int](
+      (v: Int) => v, _+_, ShuffleSuite.mergeCombineException)
+    val shuffledRdd2 = new ShuffledRDD(
+      pairs, aggregatorWithException1, new HashPartitioner(2))
+    evaluating { shuffledRdd2.collect() } should produce [SparkException]
+  }
 }
+
+object ShuffleSuite {
+  def mergeCombineException(x: Int, y: Int): Int = {
+    throw new SparkException("Exception for map-side combine.")
+    x + y
+  }
+}
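The trick in this test is worth naming: ShuffleSuite.mergeCombineException is a poison function, a mergeCombiners that always throws, so a clean run proves the combine-off path never invokes it, while the ShouldMatchers assertion "evaluating { ... } should produce [SparkException]" checks that the combine-on path does. The same pattern in miniature, outside ScalaTest; all names here are made up.

// Poison-function pattern: pass a function that throws to prove a path is dead.
object PoisonFunctionSketch {
  def poisonCombine(x: Int, y: Int): Int =
    throw new RuntimeException("combine should not have been called")

  // Only touches `combine` when preAggregate is on.
  def aggregate(values: Seq[Int], preAggregate: Boolean, combine: (Int, Int) => Int): Int =
    if (preAggregate) values.reduce(combine) else values.sum

  def main(args: Array[String]): Unit = {
    // Succeeds: the poison function is threaded through but never called.
    println(aggregate(Seq(1, 2, 3), preAggregate = false, combine = poisonCombine))

    // Throws, proving the pre-aggregation path does call it.
    try aggregate(Seq(1, 2, 3), preAggregate = true, combine = poisonCombine)
    catch { case e: RuntimeException => println("caught: " + e.getMessage) }
  }
}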
@@ -59,7 +59,9 @@ def parse_args():
       "WARNING: must be 64-bit; small instances won't work")
   parser.add_option("-m", "--master-instance-type", default="",
       help="Master instance type (leave empty for same as instance-type)")
-  parser.add_option("-z", "--zone", default="us-east-1b",
+  parser.add_option("-r", "--region", default="us-east-1",
+      help="EC2 region zone to launch instances in")
+  parser.add_option("-z", "--zone", default="",
       help="Availability zone to launch instances in")
   parser.add_option("-a", "--ami", default="latest",
       help="Amazon Machine Image ID to use, or 'latest' to use latest " +
@@ -470,7 +472,7 @@ def ssh(host, opts, command):
 
 def main():
   (opts, action, cluster_name) = parse_args()
-  conn = boto.connect_ec2()
+  conn = boto.ec2.connect_to_region(opts.region)
 
   # Select an AZ at random if it was not specified.
   if opts.zone == "":