Merge pull request #112 from amatsukawa/scc
Strongly connected component algorithm
This commit is contained in:
commit
da301b57fc
|
@ -0,0 +1,87 @@
|
|||
package org.apache.spark.graph.algorithms
|
||||
|
||||
import org.apache.spark.graph._
|
||||
|
||||
object StronglyConnectedComponents {
|
||||
|
||||
/**
|
||||
* Compute the strongly connected component (SCC) of each vertex and return an RDD with the vertex
|
||||
* value containing the lowest vertex id in the SCC containing that vertex.
|
||||
*
|
||||
* @tparam VD the vertex attribute type (discarded in the computation)
|
||||
* @tparam ED the edge attribute type (preserved in the computation)
|
||||
*
|
||||
* @param graph the graph for which to compute the SCC
|
||||
*
|
||||
* @return a graph with vertex attributes containing the smallest vertex id in each SCC
|
||||
*/
|
||||
def run[VD: Manifest, ED: Manifest](graph: Graph[VD, ED], numIter: Int): Graph[Vid, ED] = {
|
||||
|
||||
// the graph we update with final SCC ids, and the graph we return at the end
|
||||
var sccGraph = graph.mapVertices { case (vid, _) => vid }
|
||||
// graph we are going to work with in our iterations
|
||||
var sccWorkGraph = graph.mapVertices { case (vid, _) => (vid, false) }
|
||||
|
||||
var numVertices = sccWorkGraph.numVertices
|
||||
var iter = 0
|
||||
while (sccWorkGraph.numVertices > 0 && iter < numIter) {
|
||||
iter += 1
|
||||
do {
|
||||
numVertices = sccWorkGraph.numVertices
|
||||
sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.outDegrees) {
|
||||
(vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true)
|
||||
}
|
||||
sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.inDegrees) {
|
||||
(vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true)
|
||||
}
|
||||
|
||||
// get all vertices to be removed
|
||||
val finalVertices = sccWorkGraph.vertices
|
||||
.filter { case (vid, (scc, isFinal)) => isFinal}
|
||||
.mapValues { (vid, data) => data._1}
|
||||
|
||||
// write values to sccGraph
|
||||
sccGraph = sccGraph.outerJoinVertices(finalVertices) {
|
||||
(vid, scc, opt) => opt.getOrElse(scc)
|
||||
}
|
||||
// only keep vertices that are not final
|
||||
sccWorkGraph = sccWorkGraph.subgraph(vpred = (vid, data) => !data._2)
|
||||
} while (sccWorkGraph.numVertices < numVertices)
|
||||
|
||||
sccWorkGraph = sccWorkGraph.mapVertices{ case (vid, (color, isFinal)) => (vid, isFinal) }
|
||||
|
||||
// collect min of all my neighbor's scc values, update if it's smaller than mine
|
||||
// then notify any neighbors with scc values larger than mine
|
||||
sccWorkGraph = GraphLab[(Vid, Boolean), ED, Vid](sccWorkGraph, Integer.MAX_VALUE)(
|
||||
(vid, e) => e.otherVertexAttr(vid)._1,
|
||||
(vid1, vid2) => math.min(vid1, vid2),
|
||||
(vid, scc, optScc) =>
|
||||
(math.min(scc._1, optScc.getOrElse(scc._1)), scc._2),
|
||||
(vid, e) => e.vertexAttr(vid)._1 < e.otherVertexAttr(vid)._1
|
||||
)
|
||||
|
||||
// start at root of SCCs. Traverse values in reverse, notify all my neighbors
|
||||
// do not propagate if colors do not match!
|
||||
sccWorkGraph = GraphLab[(Vid, Boolean), ED, Boolean](
|
||||
sccWorkGraph,
|
||||
Integer.MAX_VALUE,
|
||||
EdgeDirection.Out,
|
||||
EdgeDirection.In
|
||||
)(
|
||||
// vertex is final if it is the root of a color
|
||||
// or it has the same color as a neighbor that is final
|
||||
(vid, e) => (vid == e.vertexAttr(vid)._1) || (e.vertexAttr(vid)._1 == e.otherVertexAttr(vid)._1),
|
||||
(final1, final2) => final1 || final2,
|
||||
(vid, scc, optFinal) =>
|
||||
(scc._1, scc._2 || optFinal.getOrElse(false)),
|
||||
// activate neighbor if they are not final, you are, and you have the same color
|
||||
(vid, e) => e.vertexAttr(vid)._2 &&
|
||||
!e.otherVertexAttr(vid)._2 && (e.vertexAttr(vid)._1 == e.otherVertexAttr(vid)._1),
|
||||
// start at root of colors
|
||||
(vid, data) => vid == data._1
|
||||
)
|
||||
}
|
||||
sccGraph
|
||||
}
|
||||
|
||||
}
|
|
@ -199,6 +199,49 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
} // end of reverse chain connected components
|
||||
|
||||
test("Island Strongly Connected Components") {
|
||||
withSpark(new SparkContext("local", "test")) { sc =>
|
||||
val vertices = sc.parallelize((1L to 5L).map(x => (x, -1)))
|
||||
val edges = sc.parallelize(Seq.empty[Edge[Int]])
|
||||
val graph = Graph(vertices, edges)
|
||||
val sccGraph = StronglyConnectedComponents.run(graph, 5)
|
||||
for ((id, scc) <- sccGraph.vertices.collect) {
|
||||
assert(id == scc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("Cycle Strongly Connected Components") {
|
||||
withSpark(new SparkContext("local", "test")) { sc =>
|
||||
val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7)))
|
||||
val graph = Graph.fromEdgeTuples(rawEdges, -1)
|
||||
val sccGraph = StronglyConnectedComponents.run(graph, 20)
|
||||
for ((id, scc) <- sccGraph.vertices.collect) {
|
||||
assert(0L == scc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("2 Cycle Strongly Connected Components") {
|
||||
withSpark(new SparkContext("local", "test")) { sc =>
|
||||
val edges =
|
||||
Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
|
||||
Array(3L -> 4L, 4L -> 5L, 5L -> 3L) ++
|
||||
Array(6L -> 0L, 5L -> 7L)
|
||||
val rawEdges = sc.parallelize(edges)
|
||||
val graph = Graph.fromEdgeTuples(rawEdges, -1)
|
||||
val sccGraph = StronglyConnectedComponents.run(graph, 20)
|
||||
for ((id, scc) <- sccGraph.vertices.collect) {
|
||||
if (id < 3)
|
||||
assert(0L == scc)
|
||||
else if (id < 6)
|
||||
assert(3L == scc)
|
||||
else
|
||||
assert(id == scc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("Count a single triangle") {
|
||||
withSpark(new SparkContext("local", "test")) { sc =>
|
||||
val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2)
|
||||
|
|
Loading…
Reference in a new issue