Fixed a bug in VTableReplicated that caused all vertices to always be broadcast: buildBuffer shipped an attribute for every candidate vertex id, even ids with no attribute defined in the local VertexPartition.

Reynold Xin 2013-12-05 23:25:53 -08:00
parent a6075ba11f
commit 15168d6c4d
3 changed files with 32 additions and 17 deletions

Pregel.scala

@@ -105,7 +105,7 @@ object Pregel {
       // receive the messages
       val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices
       // replicate the changed vertices
-      g = graph.deltaJoinVertices(changedVerts)
+      g = g.deltaJoinVertices(changedVerts)
       // count the iteration
       i += 1
     }
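
The substantive change in this hunk is the receiver of deltaJoinVertices: joining against the original `graph` meant every iteration discarded the updates folded in by earlier iterations, while joining against the current `g` accumulates them. A minimal sketch of the bug class, where `initial` and `update` are hypothetical stand-ins for the starting graph and the per-iteration join:

    // Hypothetical stand-ins for the graph and one refinement step.
    def update(x: Int): Int = x + 1
    val initial = 0

    // Buggy shape: each pass restarts from the original value, so the
    // result is always update(initial) and earlier work is lost.
    var buggy = initial
    for (_ <- 1 to 10) buggy = update(initial) // buggy == 1, not 10

    // Fixed shape: each pass builds on the previous result.
    var fixed = initial
    for (_ <- 1 to 10) fixed = update(fixed) // fixed == 10
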
@@ -177,19 +177,19 @@ object Pregel {
     var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) )
     // compute the messages
     var messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache()
-    var activeMessages = messages.count
+    var activeMessages = messages.count()
     // Loop
     var i = 0
     while (activeMessages > 0) {
       // receive the messages
       val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices
       // replicate the changed vertices
-      g = graph.deltaJoinVertices(changedVerts)
+      g = g.deltaJoinVertices(changedVerts)
       val oldMessages = messages
       // compute the messages
-      messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache
-      activeMessages = messages.count
+      messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache()
+      activeMessages = messages.count()
       // after counting we can unpersist the old messages
       oldMessages.unpersist(blocking=false)
       // count the iteration
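Besides switching `.count`/`.cache` to the parenthesized forms conventional for side-effecting Scala methods, this hunk encodes a common Spark pattern: cache the new messages, force them with count(), and only then unpersist the previous iteration's RDD, so the old copy is not dropped while the new one still needs it to compute. A minimal self-contained sketch of the pattern on a plain RDD (the function name is illustrative):

    import org.apache.spark.rdd.RDD

    def advance(prev: RDD[Long]): RDD[Long] = {
      val next = prev.map(_ + 1).cache() // keep the new iteration's result around
      next.count()                       // materialize next while prev is still cached
      prev.unpersist(blocking = false)   // only now is it safe to drop the old copy
      next
    }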

VTableReplicated.scala

@@ -2,7 +2,7 @@ package org.apache.spark.graph.impl
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.collection.OpenHashSet
+import org.apache.spark.util.collection.{PrimitiveVector, OpenHashSet}
 import org.apache.spark.graph._
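
PrimitiveVector, imported here alongside the existing OpenHashSet, is a Spark-internal (private[spark]) growable array specialized for primitive element types, so appends avoid boxing. A small usage sketch matching the API the next hunk relies on (a constructor size hint, +=, and trim().array); note it is only callable from code compiled inside Spark's own packages:

    import org.apache.spark.util.collection.PrimitiveVector

    val v = new PrimitiveVector[Long](4)  // initial size hint; grows as needed
    v += 1L
    v += 2L
    val arr: Array[Long] = v.trim().array // trim capacity to length, then expose the backing array
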
@@ -109,14 +109,23 @@ object VTableReplicated {
   def buildBuffer[VD: ClassManifest](pid2vidIter: Iterator[Array[Array[Vid]]], vertexPartIter: Iterator[VertexPartition[VD]]) = {
     val pid2vid: Array[Array[Vid]] = pid2vidIter.next()
     val vertexPart: VertexPartition[VD] = vertexPartIter.next()
-    val output = new Array[(Pid, VertexAttributeBlock[VD])](pid2vid.size)
-    //val output = mmm.newArray(pid2vid.size)
-    for (pid <- 0 until pid2vid.size) {
-      val block = new VertexAttributeBlock(
-        pid2vid(pid), pid2vid(pid).map(vid => vertexPart(vid)).toArray)
-      output(pid) = (pid, block)
+    Iterator.tabulate(pid2vid.size) { pid =>
+      val vidsCandidate = pid2vid(pid)
+      val size = vidsCandidate.length
+      val vids = new PrimitiveVector[Vid](pid2vid(pid).size)
+      val attrs = new PrimitiveVector[VD](pid2vid(pid).size)
+      var i = 0
+      while (i < size) {
+        val vid = vidsCandidate(i)
+        if (vertexPart.isDefined(vid)) {
+          vids += vid
+          attrs += vertexPart(vid)
+        }
+        i += 1
+      }
+      (pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array))
     }
-    output.iterator
   }
 }
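
This hunk is the bug fix named in the commit title: the old loop shipped a VertexAttributeBlock containing every candidate vertex id routed to a partition, while the new loop keeps only the ids whose attributes are actually defined in the local VertexPartition (vertexPart.isDefined(vid)), i.e. only the changed vertices. A dependency-free sketch of the same filtering idea, with `buildBlock`, `defined`, and `attrOf` as hypothetical stand-ins for buildBuffer, VertexPartition.isDefined, and VertexPartition.apply, and ArrayBuilder in place of the Spark-internal PrimitiveVector:

    import scala.collection.mutable.ArrayBuilder

    def buildBlock(
        candidates: Array[Long],
        defined: Long => Boolean,
        attrOf: Long => String): (Array[Long], Array[String]) = {
      val vids = ArrayBuilder.make[Long]()
      val attrs = ArrayBuilder.make[String]()
      var i = 0
      while (i < candidates.length) {
        val vid = candidates(i)
        if (defined(vid)) { // previously every candidate was shipped unconditionally
          vids += vid
          attrs += attrOf(vid)
        }
        i += 1
      }
      (vids.result(), attrs.result())
    }
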

GraphSuite.scala

@@ -159,14 +159,20 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     val star = Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: Vid, x: Vid))), "v1").cache()
     // Modify only vertices whose vids are even
-    val newVerts = star.vertices.mapValues((vid, attr) => if (vid % 2 == 0) "v2" else attr)
-    val changedVerts = star.vertices.diff(newVerts)
+    val changedVerts = star.vertices.filter(_._1 % 2 == 0).mapValues((vid, attr) => "v2")
     // Apply the modification to the graph
-    val changedStar = star.deltaJoinVertices(newVerts, changedVerts)
+    val changedStar = star.deltaJoinVertices(changedVerts)
+
+    val newVertices = star.vertices.leftZipJoin(changedVerts) { (vid, oldVd, newVdOpt) =>
+      newVdOpt match {
+        case Some(newVd) => newVd
+        case None => oldVd
+      }
+    }
     // The graph's vertices should be correct
-    assert(changedStar.vertices.collect().toSet === newVerts.collect().toSet)
+    assert(changedStar.vertices.collect().toSet === newVertices.collect().toSet)
     // Send the leaf attributes to the center
     val sums = changedStar.mapReduceTriplets(