Finished updating connected components to used Pregel like abstraction and created a series of tests in the AnalyticsSuite.

This commit is contained in:
Joseph E. Gonzalez 2013-10-28 11:52:26 -07:00
parent a2287ae138
commit d6a902f309
2 changed files with 83 additions and 3 deletions

View file

@ -167,8 +167,17 @@ object Analytics extends Logging {
* and return an RDD with the vertex value containing the
* lowest vertex id in the connected component containing
* that vertex.
*
* @tparam VD the vertex attribute type (discarded in the computation)
* @tparam ED the edge attribute type (preserved in the computation)
*
* @param graph the graph for which to compute the connected components
*
* @return a graph with vertex attributes containing the smallest vertex
* in each connected component
*/
def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]) = {
def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]):
Graph[Vid, ED] = {
val ccGraph = graph.mapVertices { case (vid, _) => vid }
def sendMessage(id: Vid, edge: EdgeTriplet[Vid, ED]): Option[Vid] = {
@ -179,21 +188,27 @@ object Analytics extends Logging {
}
val initialMessage = Long.MaxValue
Pregel(ccGraph, initialMessage)(
Pregel(ccGraph, initialMessage, EdgeDirection.Both)(
(id, attr, msg) => math.min(attr, msg),
sendMessage,
(a,b) => math.min(a,b)
)
/**
* Originally this was implemented using the GraphLab abstraction but with
* support for message computation along all edge directions the pregel
* abstraction is sufficient
*/
// GraphLab(ccGraph, gatherDirection = EdgeDirection.Both, scatterDirection = EdgeDirection.Both)(
// (me_id, edge) => edge.otherVertexAttr(me_id), // gather
// (a: Vid, b: Vid) => math.min(a, b), // merge
// (id, data, a: Option[Vid]) => math.min(data, a.getOrElse(Long.MaxValue)), // apply
// (me_id, edge) => (edge.vertexAttr(me_id) < edge.otherVertexAttr(me_id))
// )
} // end of connectedComponents
}
def main(args: Array[String]) = {
val host = args(0)
val taskType = args(1)

View file

@ -79,6 +79,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
} // end of test Star PageRank
test("Grid PageRank") {
withSpark(new SparkContext("local", "test")) { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
@ -104,4 +105,68 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
} // end of Grid PageRank
test("Grid Connected Components") {
withSpark(new SparkContext("local", "test")) { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
val ccGraph = Analytics.connectedComponents(gridGraph).cache()
val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
assert(maxCCid === 0)
}
} // end of Grid connected components
test("Reverse Grid Connected Components") {
withSpark(new SparkContext("local", "test")) { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse
val ccGraph = Analytics.connectedComponents(gridGraph).cache()
val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
assert(maxCCid === 0)
}
} // end of Grid connected components
test("Chain Connected Components") {
withSpark(new SparkContext("local", "test")) { sc =>
val chain1 = (0 until 9).map(x => (x, x+1) )
val chain2 = (10 until 20).map(x => (x, x+1) )
val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
val twoChains = Graph(rawEdges)
val ccGraph = Analytics.connectedComponents(twoChains).cache()
val vertices = ccGraph.vertices.collect
for ( (id, cc) <- vertices ) {
if(id < 10) { assert(cc === 0) }
else { assert(cc === 10) }
}
val ccMap = vertices.toMap
println(ccMap)
for( id <- 0 until 20 ) {
if(id < 10) { assert(ccMap(id) === 0) }
else { assert(ccMap(id) === 10) }
}
}
} // end of chain connected components
test("Reverse Chain Connected Components") {
withSpark(new SparkContext("local", "test")) { sc =>
val chain1 = (0 until 9).map(x => (x, x+1) )
val chain2 = (10 until 20).map(x => (x, x+1) )
val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
val twoChains = Graph(rawEdges).reverse
val ccGraph = Analytics.connectedComponents(twoChains).cache()
val vertices = ccGraph.vertices.collect
for ( (id, cc) <- vertices ) {
if(id < 10) { assert(cc === 0) }
else { assert(cc === 10) }
}
val ccMap = vertices.toMap
println(ccMap)
for( id <- 0 until 20 ) {
if(id < 10) { assert(ccMap(id) === 0) }
else { assert(ccMap(id) === 10) }
}
}
} // end of chain connected components
} // end of AnalyticsSuite