diff --git a/graph/src/main/scala/org/apache/spark/graph/Graph.scala b/graph/src/main/scala/org/apache/spark/graph/Graph.scala index 1fb22c56ff..b5c4fcc99b 100644 --- a/graph/src/main/scala/org/apache/spark/graph/Graph.scala +++ b/graph/src/main/scala/org/apache/spark/graph/Graph.scala @@ -227,12 +227,12 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] { * }}} * */ - def aggregateNeighbors[VD2: ClassManifest]( - mapFunc: (Vid, EdgeTriplet[VD, ED]) => Option[VD2], - mergeFunc: (VD2, VD2) => VD2, + def aggregateNeighbors[A: ClassManifest]( + mapFunc: (Vid, EdgeTriplet[VD, ED]) => Option[A], + mergeFunc: (A, A) => A, direction: EdgeDirection) - : RDD[(Vid, VD2)] - + : Graph[(VD, Option[A]), ED] + // TODO: consider a version that doesn't preserve the original VD /** * This function is used to compute a statistic for the neighborhood of each diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala b/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala index 1dba813e91..01f24a1302 100644 --- a/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala +++ b/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala @@ -44,13 +44,13 @@ object GraphLab { // Add an active attribute to all vertices to track convergence. - var activeGraph = graph.mapVertices { + var activeGraph: Graph[(Boolean, VD), ED] = graph.mapVertices { case Vertex(id, data) => (true, data) }.cache() // The gather function wrapper strips the active attribute and // only invokes the gather function on active vertices - def gather(vid: Vid, e: EdgeTriplet[(Boolean, VD), ED]) = { + def gather(vid: Vid, e: EdgeTriplet[(Boolean, VD), ED]): Option[A] = { if (e.vertex(vid).data._1) { val edge = new EdgeTriplet[VD,ED] edge.src = Vertex(e.src.id, e.src.data._2) @@ -64,14 +64,15 @@ object GraphLab { // The apply function wrapper strips the vertex of the active attribute // and only invokes the apply function on active vertices - def apply(v: Vertex[(Boolean, VD)], accum: Option[A]) = { - if (v.data._1) (true, applyFunc(Vertex(v.id, v.data._2), accum)) - else (false, v.data._2) + def apply(v: Vertex[((Boolean, VD), Option[A])]): (Boolean, VD) = { + val ((active, vData), accum) = v.data + if (active) (true, applyFunc(Vertex(v.id, vData), accum)) + else (false, vData) } // The scatter function wrapper strips the vertex of the active attribute // and only invokes the scatter function on active vertices - def scatter(rawVid: Vid, e: EdgeTriplet[(Boolean, VD), ED]) = { + def scatter(rawVid: Vid, e: EdgeTriplet[(Boolean, VD), ED]): Option[Boolean] = { val vid = e.otherVertex(rawVid).id if (e.vertex(vid).data._1) { val edge = new EdgeTriplet[VD,ED] @@ -88,24 +89,31 @@ object GraphLab { } // Used to set the active status of vertices for the next round - def applyActive(v: Vertex[(Boolean, VD)], accum: Option[Boolean]) = - (accum.getOrElse(false), v.data._2) + def applyActive(v: Vertex[((Boolean, VD), Option[Boolean])]): (Boolean, VD) = { + val ((prevActive, vData), newActive) = v.data + (newActive.getOrElse(false), vData) + } // Main Loop --------------------------------------------------------------------- var i = 0 var numActive = activeGraph.numVertices while (i < numIter && numActive > 0) { - val accUpdates: RDD[(Vid, A)] = + val gathered: Graph[((Boolean, VD), Option[A]), ED] = activeGraph.aggregateNeighbors(gather, mergeFunc, gatherDirection) - activeGraph = activeGraph.leftJoinVertices(accUpdates, apply).cache() + val applied: Graph[(Boolean, VD), ED] = gathered.mapVertices(apply).cache() + + activeGraph = applied.cache() // Scatter is basically a gather in the opposite direction so we reverse the edge direction - val activeVertices: RDD[(Vid, Boolean)] = + // activeGraph: Graph[(Boolean, VD), ED] + val scattered: Graph[((Boolean, VD), Option[Boolean]), ED] = activeGraph.aggregateNeighbors(scatter, _ || _, scatterDirection.reverse) + val newActiveGraph: Graph[(Boolean, VD), ED] = + scattered.mapVertices(applyActive) - activeGraph = activeGraph.leftJoinVertices(activeVertices, applyActive).cache() + activeGraph = newActiveGraph.cache() numActive = activeGraph.vertices.map(v => if (v.data._1) 1 else 0).reduce(_ + _) println("Number active vertices: " + numActive) diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala b/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala index 8de96680b8..9e8cc0a6d5 100644 --- a/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala +++ b/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala @@ -9,22 +9,29 @@ class GraphOps[VD: ClassManifest, ED: ClassManifest](g: Graph[VD, ED]) { lazy val numVertices: Long = g.vertices.count() - lazy val inDegrees: RDD[(Vid, Int)] = { - g.aggregateNeighbors((vid, edge) => Some(1), _+_, EdgeDirection.In) - } + lazy val inDegrees: RDD[(Vid, Int)] = degreesRDD(EdgeDirection.In) - lazy val outDegrees: RDD[(Vid, Int)] = { - g.aggregateNeighbors((vid, edge) => Some(1), _+_, EdgeDirection.Out) - } + lazy val outDegrees: RDD[(Vid, Int)] = degreesRDD(EdgeDirection.Out) - lazy val degrees: RDD[(Vid, Int)] = { - g.aggregateNeighbors((vid, edge) => Some(1), _+_, EdgeDirection.Both) - } + lazy val degrees: RDD[(Vid, Int)] = degreesRDD(EdgeDirection.Both) def collectNeighborIds(edgeDirection: EdgeDirection) : RDD[(Vid, Array[Vid])] = { - g.aggregateNeighbors( + val graph: Graph[(VD, Option[Array[Vid]]), ED] = g.aggregateNeighbors( (vid, edge) => Some(Array(edge.otherVertex(vid).id)), (a, b) => a ++ b, edgeDirection) + graph.vertices.map(v => { + val (_, neighborIds) = v.data + (v.id, neighborIds.getOrElse(Array())) + }) + } + + private def degreesRDD(edgeDirection: EdgeDirection): RDD[(Vid, Int)] = { + val degreeGraph: Graph[(VD, Option[Int]), ED] = + g.aggregateNeighbors((vid, edge) => Some(1), _+_, edgeDirection) + degreeGraph.vertices.map(v => { + val (_, degree) = v.data + (v.id, degree.getOrElse(0)) + }) } } diff --git a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala index 27b75a7988..09bcc67c8c 100644 --- a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala +++ b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala @@ -19,18 +19,25 @@ object Pregel { def mapF(vid: Vid, edge: EdgeTriplet[VD,ED]) = sendMsg(edge.otherVertex(vid).id, edge) - def runProg(v: Vertex[VD], msg: Option[A]): VD = { - if (msg.isEmpty) v.data else vprog(v, msg.get) + def runProg(vertexWithMsgs: Vertex[(VD, Option[A])]): VD = { + val (vData, msg) = vertexWithMsgs.data + val v = Vertex(vertexWithMsgs.id, vData) + msg match { + case Some(m) => vprog(v, m) + case None => v.data + } } - var msgs: RDD[(Vid, A)] = g.vertices.map{ v => (v.id, initialMsg) } + var graphWithMsgs: Graph[(VD, Option[A]), ED] = + g.mapVertices(v => (v.data, Some(initialMsg))) while (i < numIter) { - g = g.leftJoinVertices(msgs, runProg).cache() - msgs = g.aggregateNeighbors(mapF, mergeMsg, EdgeDirection.In) + val newGraph: Graph[VD, ED] = graphWithMsgs.mapVertices(runProg).cache() + graphWithMsgs = newGraph.aggregateNeighbors(mapF, mergeMsg, EdgeDirection.In) i += 1 } - g + graphWithMsgs.mapVertices(vertexWithMsgs => vertexWithMsgs.data match { + case (vData, _) => vData + }) } - }