Fixed a bug in VTableReplicated where we were always broadcasting all the vertices.
This commit is contained in:
parent
a6075ba11f
commit
15168d6c4d
|
@ -105,7 +105,7 @@ object Pregel {
|
||||||
// receive the messages
|
// receive the messages
|
||||||
val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices
|
val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices
|
||||||
// replicate the changed vertices
|
// replicate the changed vertices
|
||||||
g = graph.deltaJoinVertices(changedVerts)
|
g = g.deltaJoinVertices(changedVerts)
|
||||||
// count the iteration
|
// count the iteration
|
||||||
i += 1
|
i += 1
|
||||||
}
|
}
|
||||||
|
@ -177,19 +177,19 @@ object Pregel {
|
||||||
var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) )
|
var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) )
|
||||||
// compute the messages
|
// compute the messages
|
||||||
var messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache()
|
var messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache()
|
||||||
var activeMessages = messages.count
|
var activeMessages = messages.count()
|
||||||
// Loop
|
// Loop
|
||||||
var i = 0
|
var i = 0
|
||||||
while (activeMessages > 0) {
|
while (activeMessages > 0) {
|
||||||
// receive the messages
|
// receive the messages
|
||||||
val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices
|
val changedVerts = g.vertices.deltaJoin(messages)(vprog).cache() // updating the vertices
|
||||||
// replicate the changed vertices
|
// replicate the changed vertices
|
||||||
g = graph.deltaJoinVertices(changedVerts)
|
g = g.deltaJoinVertices(changedVerts)
|
||||||
|
|
||||||
val oldMessages = messages
|
val oldMessages = messages
|
||||||
// compute the messages
|
// compute the messages
|
||||||
messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache
|
messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache()
|
||||||
activeMessages = messages.count
|
activeMessages = messages.count()
|
||||||
// after counting we can unpersist the old messages
|
// after counting we can unpersist the old messages
|
||||||
oldMessages.unpersist(blocking=false)
|
oldMessages.unpersist(blocking=false)
|
||||||
// count the iteration
|
// count the iteration
|
||||||
|
|
|
@ -2,7 +2,7 @@ package org.apache.spark.graph.impl
|
||||||
|
|
||||||
import org.apache.spark.SparkContext._
|
import org.apache.spark.SparkContext._
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.util.collection.OpenHashSet
|
import org.apache.spark.util.collection.{PrimitiveVector, OpenHashSet}
|
||||||
|
|
||||||
import org.apache.spark.graph._
|
import org.apache.spark.graph._
|
||||||
|
|
||||||
|
@ -109,14 +109,23 @@ object VTableReplicated {
|
||||||
def buildBuffer[VD: ClassManifest](pid2vidIter: Iterator[Array[Array[Vid]]], vertexPartIter: Iterator[VertexPartition[VD]]) = {
|
def buildBuffer[VD: ClassManifest](pid2vidIter: Iterator[Array[Array[Vid]]], vertexPartIter: Iterator[VertexPartition[VD]]) = {
|
||||||
val pid2vid: Array[Array[Vid]] = pid2vidIter.next()
|
val pid2vid: Array[Array[Vid]] = pid2vidIter.next()
|
||||||
val vertexPart: VertexPartition[VD] = vertexPartIter.next()
|
val vertexPart: VertexPartition[VD] = vertexPartIter.next()
|
||||||
val output = new Array[(Pid, VertexAttributeBlock[VD])](pid2vid.size)
|
|
||||||
//val output = mmm.newArray(pid2vid.size)
|
Iterator.tabulate(pid2vid.size) { pid =>
|
||||||
for (pid <- 0 until pid2vid.size) {
|
val vidsCandidate = pid2vid(pid)
|
||||||
val block = new VertexAttributeBlock(
|
val size = vidsCandidate.length
|
||||||
pid2vid(pid), pid2vid(pid).map(vid => vertexPart(vid)).toArray)
|
val vids = new PrimitiveVector[Vid](pid2vid(pid).size)
|
||||||
output(pid) = (pid, block)
|
val attrs = new PrimitiveVector[VD](pid2vid(pid).size)
|
||||||
|
var i = 0
|
||||||
|
while (i < size) {
|
||||||
|
val vid = vidsCandidate(i)
|
||||||
|
if (vertexPart.isDefined(vid)) {
|
||||||
|
vids += vid
|
||||||
|
attrs += vertexPart(vid)
|
||||||
|
}
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
(pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array))
|
||||||
}
|
}
|
||||||
output.iterator
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -159,14 +159,20 @@ class GraphSuite extends FunSuite with LocalSparkContext {
|
||||||
val star = Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: Vid, x: Vid))), "v1").cache()
|
val star = Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: Vid, x: Vid))), "v1").cache()
|
||||||
|
|
||||||
// Modify only vertices whose vids are even
|
// Modify only vertices whose vids are even
|
||||||
val newVerts = star.vertices.mapValues((vid, attr) => if (vid % 2 == 0) "v2" else attr)
|
val changedVerts = star.vertices.filter(_._1 % 2 == 0).mapValues((vid, attr) => "v2")
|
||||||
val changedVerts = star.vertices.diff(newVerts)
|
|
||||||
|
|
||||||
// Apply the modification to the graph
|
// Apply the modification to the graph
|
||||||
val changedStar = star.deltaJoinVertices(newVerts, changedVerts)
|
val changedStar = star.deltaJoinVertices(changedVerts)
|
||||||
|
|
||||||
|
val newVertices = star.vertices.leftZipJoin(changedVerts) { (vid, oldVd, newVdOpt) =>
|
||||||
|
newVdOpt match {
|
||||||
|
case Some(newVd) => newVd
|
||||||
|
case None => oldVd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// The graph's vertices should be correct
|
// The graph's vertices should be correct
|
||||||
assert(changedStar.vertices.collect().toSet === newVerts.collect().toSet)
|
assert(changedStar.vertices.collect().toSet === newVertices.collect().toSet)
|
||||||
|
|
||||||
// Send the leaf attributes to the center
|
// Send the leaf attributes to the center
|
||||||
val sums = changedStar.mapReduceTriplets(
|
val sums = changedStar.mapReduceTriplets(
|
||||||
|
|
Loading…
Reference in a new issue