Merge pull request #112 from amatsukawa/scc

Strongly connected component algorithm
This commit is contained in:
Ankur Dave 2013-12-18 20:03:29 -08:00
commit da301b57fc
2 changed files with 130 additions and 0 deletions

View file

@ -0,0 +1,87 @@
package org.apache.spark.graph.algorithms
import org.apache.spark.graph._
object StronglyConnectedComponents {

  /**
   * Compute the strongly connected component (SCC) of each vertex and return a graph with the
   * vertex value containing the lowest vertex id in the SCC containing that vertex.
   *
   * The computation works on a shrinking "work graph" whose vertex attribute is a pair
   * `(color, isFinal)`. Each outer pass:
   *  1. repeatedly marks vertices that have no out-edges or no in-edges in the work graph as
   *     final (labelled with their own id), records every final vertex in the result graph and
   *     removes it from the work graph, until no more vertices can be removed;
   *  2. resets each remaining vertex's color to its own id and propagates the minimum color
   *     between neighbors;
   *  3. starting from the vertices whose color equals their own id, marks same-colored
   *     vertices as final.
   *
   * @tparam VD the vertex attribute type (discarded in the computation)
   * @tparam ED the edge attribute type (preserved in the computation)
   *
   * @param graph the graph for which to compute the SCC
   * @param numIter the maximum number of outer passes to run; vertices whose SCC has not been
   *                determined within this budget keep their own vertex id as their value
   *
   * @return a graph with vertex attributes containing the smallest vertex id in each SCC
   */
  def run[VD: Manifest, ED: Manifest](graph: Graph[VD, ED], numIter: Int): Graph[Vid, ED] = {
    // the graph we update with final SCC ids, and the graph we return at the end
    var sccGraph = graph.mapVertices { case (vid, _) => vid }
    // graph we are going to work with in our iterations
    var sccWorkGraph = graph.mapVertices { case (vid, _) => (vid, false) }

    var numVertices = sccWorkGraph.numVertices
    var iter = 0
    while (sccWorkGraph.numVertices > 0 && iter < numIter) {
      iter += 1
      do {
        numVertices = sccWorkGraph.numVertices
        // a vertex missing from outDegrees / inDegrees cannot lie on a cycle in the work
        // graph, so it is marked final with its own id as SCC value
        sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.outDegrees) {
          (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true)
        }
        sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.inDegrees) {
          (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true)
        }

        // get all vertices to be removed
        val finalVertices = sccWorkGraph.vertices
          .filter { case (vid, (scc, isFinal)) => isFinal }
          .mapValues { (vid, data) => data._1 }

        // write values to sccGraph
        sccGraph = sccGraph.outerJoinVertices(finalVertices) {
          (vid, scc, opt) => opt.getOrElse(scc)
        }
        // only keep vertices that are not final
        sccWorkGraph = sccWorkGraph.subgraph(vpred = (vid, data) => !data._2)
      } while (sccWorkGraph.numVertices < numVertices)

      // reset every remaining vertex's color to its own id before propagating
      sccWorkGraph = sccWorkGraph.mapVertices { case (vid, (color, isFinal)) => (vid, isFinal) }

      // collect min of all my neighbor's scc values, update if it's smaller than mine
      // then notify any neighbors with scc values larger than mine
      sccWorkGraph = GraphLab[(Vid, Boolean), ED, Vid](sccWorkGraph, Integer.MAX_VALUE)(
        (vid, e) => e.otherVertexAttr(vid)._1,
        (vid1, vid2) => math.min(vid1, vid2),
        (vid, scc, optScc) =>
          (math.min(scc._1, optScc.getOrElse(scc._1)), scc._2),
        (vid, e) => e.vertexAttr(vid)._1 < e.otherVertexAttr(vid)._1
      )

      // start at root of SCCs. Traverse values in reverse, notify all my neighbors
      // do not propagate if colors do not match!
      sccWorkGraph = GraphLab[(Vid, Boolean), ED, Boolean](
        sccWorkGraph,
        Integer.MAX_VALUE,
        EdgeDirection.Out,
        EdgeDirection.In
      )(
        // vertex is final if it is the root of a color
        // or it has the same color as a neighbor
        // NOTE(review): the neighbor's final flag is not checked here; this presumably relies
        // on a vertex only being activated by an already-final same-colored neighbor (see the
        // scatter predicate below) - confirm against GraphLab's activation semantics
        (vid, e) =>
          (vid == e.vertexAttr(vid)._1) ||
            (e.vertexAttr(vid)._1 == e.otherVertexAttr(vid)._1),
        (final1, final2) => final1 || final2,
        (vid, scc, optFinal) =>
          (scc._1, scc._2 || optFinal.getOrElse(false)),
        // activate neighbor if they are not final, you are, and you have the same color
        (vid, e) => e.vertexAttr(vid)._2 &&
          !e.otherVertexAttr(vid)._2 && (e.vertexAttr(vid)._1 == e.otherVertexAttr(vid)._1),
        // start at root of colors
        (vid, data) => vid == data._1
      )
    }

    // Vertices marked final by the last pass are normally recorded at the start of the next
    // pass. When the loop stops because numIter was reached there is no next pass, so record
    // them here; otherwise their computed SCC ids would be silently dropped. This is a no-op
    // when the work graph holds no final vertices.
    val pendingFinalVertices = sccWorkGraph.vertices
      .filter { case (vid, (scc, isFinal)) => isFinal }
      .mapValues { (vid, data) => data._1 }
    sccGraph = sccGraph.outerJoinVertices(pendingFinalVertices) {
      (vid, scc, opt) => opt.getOrElse(scc)
    }

    sccGraph
  }
}

View file

@ -199,6 +199,49 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
} }
} // end of reverse chain connected components } // end of reverse chain connected components
// Five isolated vertices and no edges: every vertex must be reported as its own SCC.
test("Island Strongly Connected Components") {
  withSpark(new SparkContext("local", "test")) { sc =>
    val isolated = sc.parallelize((1L to 5L).map(x => (x, -1)))
    val noEdges = sc.parallelize(Seq.empty[Edge[Int]])
    val islands = Graph(isolated, noEdges)
    val labelled = StronglyConnectedComponents.run(islands, 5)
    labelled.vertices.collect.foreach { case (id, scc) =>
      assert(id == scc)
    }
  }
}
// A single directed ring 0 -> 1 -> ... -> 6 -> 0: all vertices share the SCC labelled 0.
test("Cycle Strongly Connected Components") {
  withSpark(new SparkContext("local", "test")) { sc =>
    val ring = sc.parallelize((0L to 6L).map(src => (src, (src + 1) % 7)))
    val ringGraph = Graph.fromEdgeTuples(ring, -1)
    val labelled = StronglyConnectedComponents.run(ringGraph, 20)
    labelled.vertices.collect.foreach { case (_, scc) =>
      assert(0L == scc)
    }
  }
}
// Two disjoint 3-cycles ({0,1,2} and {3,4,5}) plus two dangling edges (6 -> 0 and 5 -> 7).
// Cycle members take the smallest id of their cycle; vertices 6 and 7 are their own SCCs.
test("2 Cycle Strongly Connected Components") {
  withSpark(new SparkContext("local", "test")) { sc =>
    val firstCycle = Array(0L -> 1L, 1L -> 2L, 2L -> 0L)
    val secondCycle = Array(3L -> 4L, 4L -> 5L, 5L -> 3L)
    val tails = Array(6L -> 0L, 5L -> 7L)
    val rawEdges = sc.parallelize(firstCycle ++ secondCycle ++ tails)
    val graph = Graph.fromEdgeTuples(rawEdges, -1)
    val labelled = StronglyConnectedComponents.run(graph, 20)
    labelled.vertices.collect.foreach { case (id, scc) =>
      if (id < 3) {
        assert(0L == scc)
      } else if (id < 6) {
        assert(3L == scc)
      } else {
        assert(id == scc)
      }
    }
  }
}
test("Count a single triangle") { test("Count a single triangle") {
withSpark(new SparkContext("local", "test")) { sc => withSpark(new SparkContext("local", "test")) { sc =>
val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2) val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2)