From 2a8f3db94d779c6155d6bef8df952a90ef2b640c Mon Sep 17 00:00:00 2001
From: Dan Crankshaw
Date: Sun, 6 Oct 2013 19:52:40 -0700
Subject: [PATCH] Fixed groupEdgeTriplets - it now passes a basic unit test.

The problem was with the way the EdgeTripletRDD iterator worked. Calling
toList on it returned the last value repeatedly. Fixed by overriding toList
in the iterator.
---
 .../spark/graph/impl/EdgeTripletRDD.scala   | 29 ++++++++++++++-
 .../apache/spark/graph/impl/GraphImpl.scala | 21 ++---------
 .../org/apache/spark/graph/GraphSuite.scala | 37 +++++++------------
 3 files changed, 44 insertions(+), 43 deletions(-)

diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/EdgeTripletRDD.scala b/graph/src/main/scala/org/apache/spark/graph/impl/EdgeTripletRDD.scala
index 18d5d2b5aa..1cd48120a1 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/EdgeTripletRDD.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/EdgeTripletRDD.scala
@@ -1,5 +1,7 @@
 package org.apache.spark.graph.impl
 
+import scala.collection.mutable
+
 import org.apache.spark.Aggregator
 import org.apache.spark.Partition
 import org.apache.spark.SparkEnv
@@ -29,8 +31,8 @@ class EdgeTripletRDD[VD: ClassManifest, ED: ClassManifest](
     eTable: RDD[(Pid, EdgePartition[ED])])
   extends RDD[(VertexHashMap[VD], Iterator[EdgeTriplet[VD, ED]])](eTable.context, Nil) {
 
-  println(vTableReplicated.partitioner.get.numPartitions)
-  println(eTable.partitioner.get.numPartitions)
+  //println("ddshfkdfhds" + vTableReplicated.partitioner.get.numPartitions)
+  //println("9757984589347598734549" + eTable.partitioner.get.numPartitions)
 
   assert(vTableReplicated.partitioner == eTable.partitioner)
 
@@ -77,10 +79,33 @@ class EdgeTripletRDD[VD: ClassManifest, ED: ClassManifest](
         // assert(vmap.containsKey(e.dst.id))
         e.dst.data = vmap.get(e.dst.id)
 
+        //println("Iter called: " + pos)
         e.data = edgePartition.data(pos)
         pos += 1
         e
       }
+
+      override def toList: List[EdgeTriplet[VD, ED]] = {
+        val lb = new mutable.ListBuffer[EdgeTriplet[VD,ED]]
+        for (i <- (0 until edgePartition.size)) {
+          val currentEdge = new EdgeTriplet[VD, ED]
+          currentEdge.src = new Vertex[VD]
+          currentEdge.dst = new Vertex[VD]
+          currentEdge.src.id = edgePartition.srcIds.getLong(i)
+          // assert(vmap.containsKey(e.src.id))
+          currentEdge.src.data = vmap.get(currentEdge.src.id)
+
+          currentEdge.dst.id = edgePartition.dstIds.getLong(i)
+          // assert(vmap.containsKey(e.dst.id))
+          currentEdge.dst.data = vmap.get(currentEdge.dst.id)
+
+          currentEdge.data = edgePartition.data(i)
+          //println("Iter: " + pos + " " + e.src.id + " " + e.dst.id + " " + e.data)
+          //println("List: " + i + " " + currentEdge.src.id + " " + currentEdge.dst.id + " " + currentEdge.data)
+          lb += currentEdge
+        }
+        lb.toList
+      }
     }
     Iterator((vmap, iter))
   }
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 6d2ce70ead..a6953d764c 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -145,13 +145,14 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     // type that should have
     val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
       // toList lets us operate on all EdgeTriplets in a single partition at once
-      partIter.toList
+      partIter
+      .toList
       // groups all ETs in this partition that have the same src and dst
       // Because all ETs with the same src and dst will live on the same
       // partition due to the EdgePartitioner, this guarantees that these
       // ET groups will be complete.
       .groupBy { t: EdgeTriplet[VD, ED] =>
-        println(t.src.id + " " + t.dst.id)
+        //println("(" + t.src.id + ", " + t.dst.id + ", " + t.data + ")")
         (t.src.id, t.dst.id) }
       //.groupBy { e => (e.src, e.dst) }
       // Apply the user supplied supplied edge group function to
@@ -202,22 +203,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 
 
 
-
-  //override def groupEdges[ED2: ClassManifest](f: Iterator[EdgeTriplet[VD,ED]] => ED2 ):
-  //  Graph[VD,ED] = {
-  //    val groups = triplets.collect.toList.groupBy { t => (t.src.id, t.dst.id) }
-  //    for (k <- groups.keys) {
-  //      println("^^^^^^^^^^^^^^^^^ " + k + " ^^^^^^^^^^^^^^^^^^^^^")
-
-  //    }
-  //    val transformMap: Map[(Vid, Vid), ED2] = groups.mapValues { ts => f(ts.toIterator) }
-  //    val newList: List[((Vid, Vid), ED2)] = transformMap.toList
-  //    val newEdges: List[Edge[ED2]] = newList.map { case ((src, dst), data) => Edge(src, dst, data) }
-
-  //    newGraph(vertices, edges)
-
-  //}
-
   //////////////////////////////////////////////////////////////////////////////////////////////////
   // Lower level transformation methods
   //////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
index ce9d2104a2..8c85260c1b 100644
--- a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
@@ -47,34 +47,25 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     withSpark(new SparkContext("local", "test")) { sc =>
       val vertices = sc.parallelize(List(Vertex(6, 1),Vertex(7, 1), Vertex(8,1)))
       val edges = sc.parallelize(List(
-        Edge(6, 7, 0.4),
-        Edge(6, 7, 0.9),
-        Edge(6, 7, 0.7),
-        Edge(7, 6, 25.0),
-        Edge(7, 6, 300.0),
-        Edge(7, 6, 600.0),
-        Edge(8, 7, 11.0),
-        Edge(7, 8, 89.0)))
+        Edge(6, 7, 4),
+        Edge(6, 7, 9),
+        Edge(6, 7, 7),
+        Edge(7, 6, 25),
+        Edge(7, 6, 300),
+        Edge(7, 6, 600),
+        Edge(8, 7, 11),
+        Edge(7, 8, 89)))
       val original = Graph(vertices, edges)
-      for (e <- original.edges) {
-        println("(" + e.src + ", " + e.dst + ", " + e.data + ")")
-      }
-      //assert(original.edges.count() === 6)
 
       val grouped = original.groupEdgeTriplets { iter =>
-        println("----------------------------------------")
         iter.map(_.data).sum
       }
 
-      for (e <- grouped.edges) {
-        println("******************************(" + e.src + ", " + e.dst + ", " + e.data + ")")
-      }
-
-      //val groups: Map[(Vid, Vid), List[Edge[Double]]] = original.edges.collect.toList.groupBy { e => (e.src, e.dst) }
-      //for (k <- groups.keys) {
-      //  println("################# " + k + " #################")
-      //}
-      //assert(grouped.edges.count() === 2)
-      //assert(grouped.edges.collect().toSet === Set(Edge(0, 1, 2.0), Edge(1, 0, 6.0)))
+      assert(grouped.edges.count() === 4)
+      assert(grouped.edges.collect().toSet === Set(
+        Edge(6, 7, 20),
+        Edge(7, 6, 925),
+        Edge(8, 7, 11),
+        Edge(7, 8, 89)))
     }
   }
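
Note: the sketch below is illustrative only and not part of the patch. Under assumed stand-in names (Cell, ReusedObjectIteratorDemo), it reproduces the failure mode the commit message describes: an iterator that mutates and re-returns one shared object, as the EdgeTriplet iterator above does, makes toList yield n references to that single object, so every element shows the last edge's values. Overriding toList to build a fresh object per element, as the patch does, avoids this.

object ReusedObjectIteratorDemo {
  // Hypothetical stand-in for the mutable EdgeTriplet that the iterator reuses.
  class Cell { var value: Int = 0 }

  def main(args: Array[String]): Unit = {
    val shared = new Cell
    // Buggy pattern: mutate one shared object and hand it back on every next().
    val reusing = (1 to 3).iterator.map { i =>
      shared.value = i
      shared
    }
    // toList keeps three references to the same object, so all entries show the final value.
    println(reusing.toList.map(_.value)) // List(3, 3, 3)

    // Analogue of the patch's fix: materialize a fresh object per element.
    val fresh = (1 to 3).iterator.map { i =>
      val c = new Cell
      c.value = i
      c
    }
    println(fresh.toList.map(_.value)) // List(1, 2, 3)
  }
}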