Added mapSideCombine flag to CoGroupedRDD. Added unit test for
CoGroupedRDD.
parent c1e9cdc49f
commit 00a11304fd
spark/rdd/CoGroupedRDD.scala
@@ -2,10 +2,11 @@ package spark.rdd
 
 import java.io.{ObjectOutputStream, IOException}
 import java.util.{HashMap => JHashMap}
+
 import scala.collection.JavaConversions
 import scala.collection.mutable.ArrayBuffer
 
-import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Partition, TaskContext}
+import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
 import spark.{Dependency, OneToOneDependency, ShuffleDependency}
 
 
@@ -28,7 +29,8 @@ private[spark] case class NarrowCoGroupSplitDep(
 private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
 
 private[spark]
-class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Partition with Serializable {
+class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep])
+  extends Partition with Serializable {
   override val index: Int = idx
   override def hashCode(): Int = idx
 }
@@ -40,7 +42,19 @@ private[spark] class CoGroupAggregator
     { (b1, b2) => b1 ++ b2 })
   with Serializable
 
-class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
+
+/**
+ * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
+ * tuple with the list of values for that key.
+ *
+ * @param rdds parent RDDs.
+ * @param part partitioner used to partition the shuffle output.
+ * @param mapSideCombine flag indicating whether to merge values before shuffle step.
+ */
+class CoGroupedRDD[K](
+  @transient var rdds: Seq[RDD[(K, _)]],
+  part: Partitioner,
+  val mapSideCombine: Boolean = true)
   extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
 
   private val aggr = new CoGroupAggregator
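
The mapSideCombine flag controls whether each map-side partition pre-groups its
values by key before the shuffle, trading map-side work and memory for fewer,
larger shuffle records. A minimal plain-Scala sketch of that per-partition step
(a stand-in for what CoGroupAggregator's combineValuesByKey does, not the Spark API):

    import scala.collection.mutable.{ArrayBuffer, HashMap}

    object MapSideCombineSketch extends App {
      // Collapse (key, value) pairs into one (key, buffer) record per distinct
      // key: the shape that gets shuffled when mapSideCombine = true.
      def combineValuesByKey[K, V](iter: Iterator[(K, V)]): Iterator[(K, ArrayBuffer[V])] = {
        val buffers = new HashMap[K, ArrayBuffer[V]]
        for ((k, v) <- iter) buffers.getOrElseUpdate(k, new ArrayBuffer[V]) += v
        buffers.iterator
      }

      val partition = Iterator((1, "one"), (1, "another one"), (2, "two"))
      // Three input records become two shuffle records, e.g.
      // (1,ArrayBuffer(one, another one)) and (2,ArrayBuffer(two)), in hash order.
      combineValuesByKey(partition).foreach(println)
    }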
@@ -52,8 +66,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
         new OneToOneDependency(rdd)
       } else {
         logInfo("Adding shuffle dependency with " + rdd)
-        val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
-        new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
+        if (mapSideCombine) {
+          val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
+          new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
+        } else {
+          new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part)
+        }
       }
     }
   }
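
Note that the flag also changes the value type carried by the shuffle dependency:
with map-side combine on, the shuffled values are the pre-built ArrayBuffer[Any]
buffers; with it off, they are the raw values. This is why compute() below fetches
the two cases with different type parameters.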
@@ -82,6 +100,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
     val numRdds = split.deps.size
     // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
     val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
+
     def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
       val seq = map.get(k)
       if (seq != null) {
@@ -92,6 +111,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
         seq
       }
     }
+
     for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
       case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
         // Read them from the parent
@@ -102,9 +122,16 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
       case ShuffleCoGroupSplitDep(shuffleId) => {
         // Read map outputs of shuffle
         val fetcher = SparkEnv.get.shuffleFetcher
-        val fetchItr = fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics)
-        for ((k, vs) <- fetchItr) {
-          getSeq(k)(depNum) ++= vs
+        if (mapSideCombine) {
+          // With map side combine on, for each key, the shuffle fetcher returns a list of values.
+          fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics).foreach {
+            case (key, values) => getSeq(key)(depNum) ++= values
+          }
+        } else {
+          // With map side combine off, for each key the shuffle fetcher returns a single value.
+          fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics).foreach {
+            case (key, value) => getSeq(key)(depNum) += value
+          }
         }
       }
     }
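
Both fetch branches feed the same per-key structure; only the record shape coming
off the wire differs. A compact plain-Scala sketch of the merge under that
assumption (getSeq and numRdds mirror the definitions above; not the Spark
shuffle API):

    import scala.collection.mutable.ArrayBuffer

    object FetchMergeSketch extends App {
      val numRdds = 2
      val map = scala.collection.mutable.Map[Int, Seq[ArrayBuffer[Any]]]()
      // One growable buffer per parent RDD for each key, as in compute() above.
      def getSeq(k: Int): Seq[ArrayBuffer[Any]] =
        map.getOrElseUpdate(k, Seq.fill(numRdds)(new ArrayBuffer[Any]))

      // mapSideCombine on: the fetcher yields one (key, list of values) per key.
      Seq((1, Seq("one", "another one"))).foreach {
        case (key, values) => getSeq(key)(0) ++= values
      }

      // mapSideCombine off: the fetcher yields one (key, value) per original pair.
      Seq((1, "one1"), (1, "another one1")).foreach {
        case (key, value) => getSeq(key)(1) += value
      }

      println(map(1)) // List(ArrayBuffer(one, another one), ArrayBuffer(one1, another one1))
    }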

spark/RDDSuite.scala
@@ -3,7 +3,7 @@ package spark
 import scala.collection.mutable.HashMap
 import org.scalatest.FunSuite
 import spark.SparkContext._
-import spark.rdd.{CoalescedRDD, PartitionPruningRDD}
+import spark.rdd.{CoalescedRDD, CoGroupedRDD, PartitionPruningRDD}
 
 class RDDSuite extends FunSuite with LocalSparkContext {
 
@@ -123,6 +123,36 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     assert(rdd.collect().toList === List(1, 2, 3, 4))
   }
 
+  test("cogrouped RDDs") {
+    sc = new SparkContext("local", "test")
+    val rdd1 = sc.makeRDD(Array((1, "one"), (1, "another one"), (2, "two"), (3, "three")), 2)
+    val rdd2 = sc.makeRDD(Array((1, "one1"), (1, "another one1"), (2, "two1")), 2)
+
+    // Use cogroup function
+    val cogrouped = rdd1.cogroup(rdd2).collectAsMap()
+    assert(cogrouped(1) === (Seq("one", "another one"), Seq("one1", "another one1")))
+    assert(cogrouped(2) === (Seq("two"), Seq("two1")))
+    assert(cogrouped(3) === (Seq("three"), Seq()))
+
+    // Construct CoGroupedRDD directly, with map side combine enabled
+    val cogrouped1 = new CoGroupedRDD[Int](
+      Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
+      new HashPartitioner(3),
+      true).collectAsMap()
+    assert(cogrouped1(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
+    assert(cogrouped1(2).toSeq === Seq(Seq("two"), Seq("two1")))
+    assert(cogrouped1(3).toSeq === Seq(Seq("three"), Seq()))
+
+    // Construct CoGroupedRDD directly, with map side combine disabled
+    val cogrouped2 = new CoGroupedRDD[Int](
+      Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
+      new HashPartitioner(3),
+      false).collectAsMap()
+    assert(cogrouped2(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
+    assert(cogrouped2(2).toSeq === Seq(Seq("two"), Seq("two1")))
+    assert(cogrouped2(3).toSeq === Seq(Seq("three"), Seq()))
+  }
+
   test("coalesced RDDs") {
     sc = new SparkContext("local", "test")
     val data = sc.parallelize(1 to 10, 10)
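
The two direct-construction cases assert identical outputs with the flag on and
off, which is what the test is checking: mapSideCombine changes how the shuffle
is executed, not the cogrouped result.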