diff --git a/graph/src/main/scala/org/apache/spark/graph/Graph.scala b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
index 86502182fb..a0907c319a 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Graph.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
@@ -278,7 +278,7 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
     (mapFunc: (Vid, VD, Option[U]) => VD2)
     : Graph[VD2, ED]
 
-  def deltaJoinVertices(newVerts: VertexRDD[VD], changedVerts: VertexRDD[VD]): Graph[VD, ED]
+  def deltaJoinVertices(changedVerts: VertexRDD[VD]): Graph[VD, ED]
 
   // Save a copy of the GraphOps object so there is always one unique GraphOps object
   // for a given Graph object, and thus the lazy vals in GraphOps would work as intended.
diff --git a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
index ffbb6fe3ca..1e59a39485 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
@@ -103,12 +103,9 @@ object Pregel {
       // compute the messages
       val messages = g.mapReduceTriplets(sendMsg, mergeMsg) // broadcast & aggregation
       // receive the messages
-      val newVerts = g.vertices.zipJoin(messages)(vprog).cache() // updating the vertices
-      val changedVerts = g.vertices.diff(newVerts)
-      println("Replicating %d changed vertices instead of %d total vertices".format(
-        changedVerts.count, newVerts.count))
+      val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices
       // replicate the changed vertices
-      g = graph.deltaJoinVertices(newVerts, changedVerts)
+      g = graph.deltaJoinVertices(changedVerts)
       // count the iteration
       i += 1
     }
@@ -185,12 +182,9 @@ object Pregel {
     var i = 0
     while (activeMessages > 0) {
       // receive the messages
-      val newVerts = g.vertices.zipJoin(messages)(vprog).cache() // updating the vertices
-      val changedVerts = g.vertices.diff(newVerts)
-      println("Replicating %d changed vertices instead of %d total vertices".format(
-        changedVerts.count, newVerts.count))
+      val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices
       // replicate the changed vertices
-      g = graph.deltaJoinVertices(newVerts, changedVerts)
+      g = graph.deltaJoinVertices(changedVerts)
 
       val oldMessages = messages
       // compute the messages
diff --git a/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala b/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala
index 5afe2df0ca..1b8ab89ebe 100644
--- a/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala
@@ -207,6 +207,14 @@
     }
   }
 
+  def deltaJoin[VD2: ClassManifest]
+      (other: VertexRDD[VD2])(f: (Vid, VD, VD2) => VD): VertexRDD[VD] =
+  {
+    this.zipVertexPartitions(other) { (thisPart, otherPart) =>
+      thisPart.deltaJoin(otherPart)(f)
+    }
+  }
+
   /**
    * Left join this VertexSet with another VertexSet which has the
    * same Index. This function will fail if both VertexSets do not
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 6e9566e060..4300812990 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -239,9 +239,13 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     new GraphImpl(newVTable, edges, vertexPlacement)
   }
 
-  override def deltaJoinVertices(
-      newVerts: VertexRDD[VD],
-      changedVerts: VertexRDD[VD]): Graph[VD, ED] = {
+  override def deltaJoinVertices(changedVerts: VertexRDD[VD]): Graph[VD, ED] = {
+    val newVerts = vertices.leftZipJoin(changedVerts) { (vid, oldAttr, newAttrOpt) =>
+      newAttrOpt match {
+        case Some(newAttr) => newAttr
+        case None => oldAttr
+      }
+    }
     val newVTableReplicated = new VTableReplicated(
       changedVerts, edges, vertexPlacement, Some(vTableReplicated))
     new GraphImpl(newVerts, edges, vertexPlacement, newVTableReplicated)
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
index c922350345..0af445fa7d 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
@@ -127,6 +127,30 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest](
     }
   }
 
+  /** Inner join another VertexPartition, dropping vertices whose values f leaves unchanged. */
+  def deltaJoin[VD2: ClassManifest, VD3: ClassManifest]
+      (other: VertexPartition[VD2])
+      (f: (Vid, VD, VD2) => VD3): VertexPartition[VD3] =
+  {
+    if (index != other.index) {
+      logWarning("Joining two VertexPartitions with different indexes is slow.")
+      join(createUsingIndex(other.iterator))(f)
+    } else {
+      val newValues = new Array[VD3](capacity)
+      val newMask = mask & other.mask
+
+      var i = newMask.nextSetBit(0)
+      while (i >= 0) {
+        newValues(i) = f(index.getValue(i), values(i), other.values(i))
+        if (newValues(i) == values(i)) {
+          newMask.unset(i)
+        }
+        i = newMask.nextSetBit(i + 1)
+      }
+      new VertexPartition(index, newValues, newMask)
+    }
+  }
+
   /** Left outer join another VertexPartition. */
   def leftJoin[VD2: ClassManifest, VD3: ClassManifest]
      (other: VertexPartition[VD2])