Fixed a bug in VTableReplicated where we were always broadcasting all the vertices.

This commit is contained in:
Reynold Xin 2013-12-05 23:25:53 -08:00
parent a6075ba11f
commit 15168d6c4d
3 changed files with 32 additions and 17 deletions

View file

@ -105,7 +105,7 @@ object Pregel {
// receive the messages // receive the messages
val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices
// replicate the changed vertices // replicate the changed vertices
g = graph.deltaJoinVertices(changedVerts) g = g.deltaJoinVertices(changedVerts)
// count the iteration // count the iteration
i += 1 i += 1
} }
@ -177,19 +177,19 @@ object Pregel {
var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ) var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) )
// compute the messages // compute the messages
var messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache() var messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache()
var activeMessages = messages.count var activeMessages = messages.count()
// Loop // Loop
var i = 0 var i = 0
while (activeMessages > 0) { while (activeMessages > 0) {
// receive the messages // receive the messages
val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices
// replicate the changed vertices // replicate the changed vertices
g = graph.deltaJoinVertices(changedVerts) g = g.deltaJoinVertices(changedVerts)
val oldMessages = messages val oldMessages = messages
// compute the messages // compute the messages
messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache()
activeMessages = messages.count activeMessages = messages.count()
// after counting we can unpersist the old messages // after counting we can unpersist the old messages
oldMessages.unpersist(blocking=false) oldMessages.unpersist(blocking=false)
// count the iteration // count the iteration

View file

@ -2,7 +2,7 @@ package org.apache.spark.graph.impl
import org.apache.spark.SparkContext._ import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.collection.{PrimitiveVector, OpenHashSet}
import org.apache.spark.graph._ import org.apache.spark.graph._
@ -109,14 +109,23 @@ object VTableReplicated {
def buildBuffer[VD: ClassManifest](pid2vidIter: Iterator[Array[Array[Vid]]], vertexPartIter: Iterator[VertexPartition[VD]]) = { def buildBuffer[VD: ClassManifest](pid2vidIter: Iterator[Array[Array[Vid]]], vertexPartIter: Iterator[VertexPartition[VD]]) = {
val pid2vid: Array[Array[Vid]] = pid2vidIter.next() val pid2vid: Array[Array[Vid]] = pid2vidIter.next()
val vertexPart: VertexPartition[VD] = vertexPartIter.next() val vertexPart: VertexPartition[VD] = vertexPartIter.next()
val output = new Array[(Pid, VertexAttributeBlock[VD])](pid2vid.size)
//val output = mmm.newArray(pid2vid.size) Iterator.tabulate(pid2vid.size) { pid =>
for (pid <- 0 until pid2vid.size) { val vidsCandidate = pid2vid(pid)
val block = new VertexAttributeBlock( val size = vidsCandidate.length
pid2vid(pid), pid2vid(pid).map(vid => vertexPart(vid)).toArray) val vids = new PrimitiveVector[Vid](pid2vid(pid).size)
output(pid) = (pid, block) val attrs = new PrimitiveVector[VD](pid2vid(pid).size)
var i = 0
while (i < size) {
val vid = vidsCandidate(i)
if (vertexPart.isDefined(vid)) {
vids += vid
attrs += vertexPart(vid)
}
i += 1
}
(pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array))
} }
output.iterator
} }
} }

View file

@ -159,14 +159,20 @@ class GraphSuite extends FunSuite with LocalSparkContext {
val star = Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: Vid, x: Vid))), "v1").cache() val star = Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: Vid, x: Vid))), "v1").cache()
// Modify only vertices whose vids are even // Modify only vertices whose vids are even
val newVerts = star.vertices.mapValues((vid, attr) => if (vid % 2 == 0) "v2" else attr) val changedVerts = star.vertices.filter(_._1 % 2 == 0).mapValues((vid, attr) => "v2")
val changedVerts = star.vertices.diff(newVerts)
// Apply the modification to the graph // Apply the modification to the graph
val changedStar = star.deltaJoinVertices(newVerts, changedVerts) val changedStar = star.deltaJoinVertices(changedVerts)
val newVertices = star.vertices.leftZipJoin(changedVerts) { (vid, oldVd, newVdOpt) =>
newVdOpt match {
case Some(newVd) => newVd
case None => oldVd
}
}
// The graph's vertices should be correct // The graph's vertices should be correct
assert(changedStar.vertices.collect().toSet === newVerts.collect().toSet) assert(changedStar.vertices.collect().toSet === newVertices.collect().toSet)
// Send the leaf attributes to the center // Send the leaf attributes to the center
val sums = changedStar.mapReduceTriplets( val sums = changedStar.mapReduceTriplets(