Finished updating connected components to used Pregel like abstraction and created a series of tests in the AnalyticsSuite.
This commit is contained in:
parent
a2287ae138
commit
d6a902f309
|
@ -167,8 +167,17 @@ object Analytics extends Logging {
|
|||
* and return an RDD with the vertex value containing the
|
||||
* lowest vertex id in the connected component containing
|
||||
* that vertex.
|
||||
*
|
||||
* @tparam VD the vertex attribute type (discarded in the computation)
|
||||
* @tparam ED the edge attribute type (preserved in the computation)
|
||||
*
|
||||
* @param graph the graph for which to compute the connected components
|
||||
*
|
||||
* @return a graph with vertex attributes containing the smallest vertex
|
||||
* in each connected component
|
||||
*/
|
||||
def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]) = {
|
||||
def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]):
|
||||
Graph[Vid, ED] = {
|
||||
val ccGraph = graph.mapVertices { case (vid, _) => vid }
|
||||
|
||||
def sendMessage(id: Vid, edge: EdgeTriplet[Vid, ED]): Option[Vid] = {
|
||||
|
@ -179,21 +188,27 @@ object Analytics extends Logging {
|
|||
}
|
||||
|
||||
val initialMessage = Long.MaxValue
|
||||
Pregel(ccGraph, initialMessage)(
|
||||
Pregel(ccGraph, initialMessage, EdgeDirection.Both)(
|
||||
(id, attr, msg) => math.min(attr, msg),
|
||||
sendMessage,
|
||||
(a,b) => math.min(a,b)
|
||||
)
|
||||
|
||||
/**
|
||||
* Originally this was implemented using the GraphLab abstraction but with
|
||||
* support for message computation along all edge directions the pregel
|
||||
* abstraction is sufficient
|
||||
*/
|
||||
// GraphLab(ccGraph, gatherDirection = EdgeDirection.Both, scatterDirection = EdgeDirection.Both)(
|
||||
// (me_id, edge) => edge.otherVertexAttr(me_id), // gather
|
||||
// (a: Vid, b: Vid) => math.min(a, b), // merge
|
||||
// (id, data, a: Option[Vid]) => math.min(data, a.getOrElse(Long.MaxValue)), // apply
|
||||
// (me_id, edge) => (edge.vertexAttr(me_id) < edge.otherVertexAttr(me_id))
|
||||
// )
|
||||
} // end of connectedComponents
|
||||
|
||||
}
|
||||
|
||||
|
||||
def main(args: Array[String]) = {
|
||||
val host = args(0)
|
||||
val taskType = args(1)
|
||||
|
|
|
@ -79,6 +79,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
|
|||
} // end of test Star PageRank
|
||||
|
||||
|
||||
|
||||
test("Grid PageRank") {
|
||||
withSpark(new SparkContext("local", "test")) { sc =>
|
||||
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
|
||||
|
@ -104,4 +105,68 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
|
|||
} // end of Grid PageRank
|
||||
|
||||
|
||||
test("Grid Connected Components") {
|
||||
withSpark(new SparkContext("local", "test")) { sc =>
|
||||
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
|
||||
val ccGraph = Analytics.connectedComponents(gridGraph).cache()
|
||||
val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
|
||||
assert(maxCCid === 0)
|
||||
}
|
||||
} // end of Grid connected components
|
||||
|
||||
|
||||
test("Reverse Grid Connected Components") {
|
||||
withSpark(new SparkContext("local", "test")) { sc =>
|
||||
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse
|
||||
val ccGraph = Analytics.connectedComponents(gridGraph).cache()
|
||||
val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
|
||||
assert(maxCCid === 0)
|
||||
}
|
||||
} // end of Grid connected components
|
||||
|
||||
|
||||
test("Chain Connected Components") {
|
||||
withSpark(new SparkContext("local", "test")) { sc =>
|
||||
val chain1 = (0 until 9).map(x => (x, x+1) )
|
||||
val chain2 = (10 until 20).map(x => (x, x+1) )
|
||||
val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
|
||||
val twoChains = Graph(rawEdges)
|
||||
val ccGraph = Analytics.connectedComponents(twoChains).cache()
|
||||
val vertices = ccGraph.vertices.collect
|
||||
for ( (id, cc) <- vertices ) {
|
||||
if(id < 10) { assert(cc === 0) }
|
||||
else { assert(cc === 10) }
|
||||
}
|
||||
val ccMap = vertices.toMap
|
||||
println(ccMap)
|
||||
for( id <- 0 until 20 ) {
|
||||
if(id < 10) { assert(ccMap(id) === 0) }
|
||||
else { assert(ccMap(id) === 10) }
|
||||
}
|
||||
}
|
||||
} // end of chain connected components
|
||||
|
||||
test("Reverse Chain Connected Components") {
|
||||
withSpark(new SparkContext("local", "test")) { sc =>
|
||||
val chain1 = (0 until 9).map(x => (x, x+1) )
|
||||
val chain2 = (10 until 20).map(x => (x, x+1) )
|
||||
val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
|
||||
val twoChains = Graph(rawEdges).reverse
|
||||
val ccGraph = Analytics.connectedComponents(twoChains).cache()
|
||||
val vertices = ccGraph.vertices.collect
|
||||
for ( (id, cc) <- vertices ) {
|
||||
if(id < 10) { assert(cc === 0) }
|
||||
else { assert(cc === 10) }
|
||||
}
|
||||
val ccMap = vertices.toMap
|
||||
println(ccMap)
|
||||
for( id <- 0 until 20 ) {
|
||||
if(id < 10) { assert(ccMap(id) === 0) }
|
||||
else { assert(ccMap(id) === 10) }
|
||||
}
|
||||
}
|
||||
} // end of chain connected components
|
||||
|
||||
|
||||
|
||||
} // end of AnalyticsSuite
|
||||
|
|
Loading…
Reference in a new issue