diff --git a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala index 1e59a39485..8e9ca89efc 100644 --- a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala +++ b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala @@ -105,7 +105,7 @@ object Pregel { // receive the messages val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices // replicate the changed vertices - g = graph.deltaJoinVertices(changedVerts) + g = g.deltaJoinVertices(changedVerts) // count the iteration i += 1 } @@ -177,19 +177,19 @@ object Pregel { var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ) // compute the messages var messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache() - var activeMessages = messages.count + var activeMessages = messages.count() // Loop var i = 0 while (activeMessages > 0) { // receive the messages val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices // replicate the changed vertices - g = graph.deltaJoinVertices(changedVerts) + g = g.deltaJoinVertices(changedVerts) val oldMessages = messages // compute the messages - messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache - activeMessages = messages.count + messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache() + activeMessages = messages.count() // after counting we can unpersist the old messages oldMessages.unpersist(blocking=false) // count the iteration diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicated.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicated.scala index 3e3769b9af..0c50ad09c7 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicated.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicated.scala @@ -2,7 +2,7 @@ package org.apache.spark.graph.impl import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.collection.OpenHashSet +import org.apache.spark.util.collection.{PrimitiveVector, OpenHashSet} import org.apache.spark.graph._ @@ -109,14 +109,23 @@ object VTableReplicated { def buildBuffer[VD: ClassManifest](pid2vidIter: Iterator[Array[Array[Vid]]], vertexPartIter: Iterator[VertexPartition[VD]]) = { val pid2vid: Array[Array[Vid]] = pid2vidIter.next() val vertexPart: VertexPartition[VD] = vertexPartIter.next() - val output = new Array[(Pid, VertexAttributeBlock[VD])](pid2vid.size) - //val output = mmm.newArray(pid2vid.size) - for (pid <- 0 until pid2vid.size) { - val block = new VertexAttributeBlock( - pid2vid(pid), pid2vid(pid).map(vid => vertexPart(vid)).toArray) - output(pid) = (pid, block) + + Iterator.tabulate(pid2vid.size) { pid => + val vidsCandidate = pid2vid(pid) + val size = vidsCandidate.length + val vids = new PrimitiveVector[Vid](pid2vid(pid).size) + val attrs = new PrimitiveVector[VD](pid2vid(pid).size) + var i = 0 + while (i < size) { + val vid = vidsCandidate(i) + if (vertexPart.isDefined(vid)) { + vids += vid + attrs += vertexPart(vid) + } + i += 1 + } + (pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array)) } - output.iterator } } diff --git a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala index fa4ebf3c88..514d20b76c 100644 --- a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala +++ b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala @@ -159,14 +159,20 @@ class GraphSuite extends FunSuite with LocalSparkContext { val star = Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: Vid, x: Vid))), "v1").cache() // Modify only vertices whose vids are even - val newVerts = star.vertices.mapValues((vid, attr) => if (vid % 2 == 0) "v2" else attr) - val changedVerts = star.vertices.diff(newVerts) + val changedVerts = star.vertices.filter(_._1 % 2 == 0).mapValues((vid, attr) => "v2") // Apply the modification to the graph - val changedStar = star.deltaJoinVertices(newVerts, changedVerts) + val changedStar = star.deltaJoinVertices(changedVerts) + + val newVertices = star.vertices.leftZipJoin(changedVerts) { (vid, oldVd, newVdOpt) => + newVdOpt match { + case Some(newVd) => newVd + case None => oldVd + } + } // The graph's vertices should be correct - assert(changedStar.vertices.collect().toSet === newVerts.collect().toSet) + assert(changedStar.vertices.collect().toSet === newVertices.collect().toSet) // Send the leaf attributes to the center val sums = changedStar.mapReduceTriplets(