diff --git a/core/src/main/scala/org/apache/spark/rdd/IndexedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/IndexedRDD.scala
index a881ee3a1d..e099669c22 100644
--- a/core/src/main/scala/org/apache/spark/rdd/IndexedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/IndexedRDD.scala
@@ -216,6 +216,50 @@ object IndexedRDD {
     }
   }
 
+
+  def reduceByKey[K: ClassManifest, V: ClassManifest](
+      rdd: RDD[(K, V)], reduceFun: (V, V) => V, index: RDDIndex[K]): IndexedRDD[K, V] = {
+    // Get the index partitioner
+    val partitioner = index.rdd.partitioner match {
+      case Some(p) => p
+      case None => throw new SparkException("An index must have a partitioner.")
+    }
+    // Preaggregate and shuffle if necessary
+    val partitioned =
+      if (rdd.partitioner != Some(partitioner)) {
+        // Preaggregation
+        val aggregator = new Aggregator[K, V, V](v => v, reduceFun, reduceFun)
+        val combined = rdd.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+        combined.partitionBy(partitioner) // new ShuffledRDD[K, V, (K, V)](combined, partitioner)
+      } else {
+        rdd
+      }
+
+    // Use the index to build the new values table
+    val values = index.rdd.zipPartitions(partitioned)( (indexIter, tblIter) => {
+      // Each index partition contains exactly one index map
+      val index = indexIter.next()
+      assert(!indexIter.hasNext)
+      val values = new Array[Array[V]](index.size)
+      for ((k, v) <- tblIter) {
+        if (!index.contains(k)) {
+          throw new SparkException("Error: Trying to bind an external index " +
+            "to an RDD which contains keys that are not in the index.")
+        }
+        val ind = index(k)
+        if (values(ind) == null) {
+          values(ind) = Array(v)
+        } else {
+          values(ind)(0) = reduceFun(values(ind).head, v)
+        }
+      }
+      List(values.view.map(x => if (x != null) x.toSeq else null).toSeq).iterator
+    })
+
+    new IndexedRDD[K, V](index, values)
+
+  }
+
   /**
    * Construct an index of the unique values in a given RDD.
    */
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 7fadbcf4ec..5d00917dab 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -245,7 +245,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
     if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) {
       throw new SparkException("Default partitioner cannot partition array keys.")
     }
-    new ShuffledRDD[K, V, (K, V)](self, partitioner) 
+    new ShuffledRDD[K, V, (K, V)](self, partitioner)
   }
 
   /**
diff --git a/graph/src/main/scala/org/apache/spark/graph/EdgeTriplet.scala b/graph/src/main/scala/org/apache/spark/graph/EdgeTriplet.scala
index 7dfb5caa4c..c2ef63d1fd 100644
--- a/graph/src/main/scala/org/apache/spark/graph/EdgeTriplet.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/EdgeTriplet.scala
@@ -18,12 +18,12 @@ class EdgeTriplet[VD, ED] extends Edge[ED] {
   /**
    * The source vertex attribute
    */
-  var srcAttr: VD = nullValue[VD]
+  var srcAttr: VD = _ // nullValue[VD]
 
   /**
    * The destination vertex attribute
    */
-  var dstAttr: VD = nullValue[VD]
+  var dstAttr: VD = _ // nullValue[VD]
 
   /**
    * Set the edge properties of this triplet.
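Note on the new `IndexedRDD.reduceByKey` above: each index partition carries a single key-to-slot map, so the reduction writes into a preallocated array instead of building a fresh hash table on every call. A minimal standalone sketch of that per-partition logic follows (plain Scala, no Spark required; the `Map[K, Int]` index, the `Option`-based slots, and the helper name `reduceWithIndex` are illustrative stand-ins, not part of this patch):

```scala
// Sketch: reduce a stream of (key, value) pairs into slots fixed by a
// prebuilt index, mirroring the zipPartitions body in reduceByKey above.
def reduceWithIndex[K, V](
    index: Map[K, Int],                  // key -> slot, built once and reused
    pairs: Iterator[(K, V)],
    reduceFun: (V, V) => V): Array[Option[V]] = {
  val values = Array.fill[Option[V]](index.size)(None)
  for ((k, v) <- pairs) {
    val slot = index.getOrElse(k, throw new IllegalArgumentException(
      s"Key $k is not in the index."))   // same contract as the SparkException above
    values(slot) = values(slot) match {
      case None      => Some(v)          // first value claims the slot
      case Some(acc) => Some(reduceFun(acc, v))
    }
  }
  values
}

// Example: summing with a fixed two-key index.
// reduceWithIndex(Map("a" -> 0, "b" -> 1),
//                 Iterator("a" -> 1, "b" -> 2, "a" -> 3),
//                 (x: Int, y: Int) => x + y)
// => Array(Some(4), Some(2))
```

The point of the design is that the slot assignment is computed once when the index is built and then shared across every RDD bound to it, so repeated reductions over the same key universe avoid rehashing.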
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index aa0aaaaef4..73538862b1 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -341,9 +341,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     ClosureCleaner.clean(mapFunc)
     ClosureCleaner.clean(reduceFunc)
 
-    val newVTable: RDD[(Vid, A)] =
-      vTableReplicated.join(eTable).flatMap{
-        case (pid, (vmap, edgePartition)) =>
+    // Map and preaggregate
+    val preAgg = vTableReplicated.join(eTable).flatMap{
+      case (pid, (vmap, edgePartition)) =>
         val aggMap = new VertexHashMap[A]
         val et = new EdgeTriplet[VD, ED]
         edgePartition.foreach{e =>
@@ -353,17 +353,17 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
           mapFunc(et).foreach{case (vid, a) =>
             if(aggMap.containsKey(vid)) {
               aggMap.put(vid, reduceFunc(aggMap.get(vid), a))
-            } else { aggMap.put(vid, a) }
+            } else { aggMap.put(vid, a) }
+          }
         }
-        }
 
         // Return the aggregate map
         aggMap.long2ObjectEntrySet().fastIterator().map{ entry =>
           (entry.getLongKey(), entry.getValue())
         }
-      }
-      .indexed(vTable.index).reduceByKey(reduceFunc)
+    }.partitionBy(vTable.index.rdd.partitioner.get)
 
-    newVTable
+    // Do the final reduction, reusing the index map
+    IndexedRDD.reduceByKey(preAgg, reduceFunc, vTable.index)
 
   }
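For context on the GraphImpl rewrite above: the per-partition `VertexHashMap` combine means each vertex id leaves a partition at most once, and `partitionBy(vTable.index.rdd.partitioner.get)` then co-locates the pre-combined messages with the index partitions, so the final `IndexedRDD.reduceByKey` takes its no-shuffle branch. A minimal sketch of that map-side combine follows (using `scala.collection.mutable.HashMap` in place of the specialized `VertexHashMap` and `Long` for `Vid`; both substitutions are assumptions for illustration):

```scala
import scala.collection.mutable

// Sketch: combine (vid, message) pairs within one partition before any
// shuffle, as the flatMap body above does with its per-partition aggMap.
def preAggregate[A](
    messages: Iterator[(Long, A)],       // output of mapFunc over the triplets
    reduceFunc: (A, A) => A): Iterator[(Long, A)] = {
  val aggMap = mutable.HashMap.empty[Long, A]
  for ((vid, a) <- messages) {
    aggMap(vid) = aggMap.get(vid) match {
      case Some(acc) => reduceFunc(acc, a) // later messages fold into the slot
      case None      => a                  // first message for this vid
    }
  }
  aggMap.iterator                          // at most one pair per vid survives
}
```

Combining before the shuffle bounds the shuffled data by the number of distinct vertex ids per partition rather than the number of edges, which is the motivation for replacing the old `.indexed(vTable.index).reduceByKey(reduceFunc)` chain with the explicit preaggregate-then-reduce path.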