Adding triangle count code
This commit is contained in:
parent
8719ba83c8
commit
2093a17ff3
|
@ -2,8 +2,6 @@ package org.apache.spark.graph
|
|||
|
||||
import org.apache.spark._
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* The Analytics object contains a collection of basic graph analytics
|
||||
* algorithms that operate largely on the graph structure.
|
||||
|
@ -204,6 +202,65 @@ object Analytics extends Logging {
|
|||
} // end of connectedComponents
|
||||
|
||||
|
||||
/**
 * Compute the number of triangles passing through each vertex.
 *
 * Edges that repeat the same (source, destination) pair are collapsed first.
 * Each vertex then gets the set of its neighbor ids taken over edges in both
 * directions, with the vertex itself left out. For every edge the two endpoint
 * sets are intersected (ignoring the endpoints themselves) and the size of the
 * intersection is sent to both endpoints. The per-vertex sum is finally halved,
 * because a triangle through a vertex is seen along both of the triangle's
 * edges that touch that vertex.
 *
 * NOTE(review): the halving assumes every undirected edge is present exactly
 * once. Only duplicates of the same directed pair are merged here, so when both
 * u->v and v->u are present the reported counts double, and a graph mixing
 * one-way and two-way edges looks able to produce an odd sum and trip the
 * parity assertion below. Confirm that callers canonicalize edge direction.
 *
 * @param rawGraph the input graph; its vertex attributes are not read
 * @tparam VD vertex attribute type of the input graph
 * @tparam ED edge attribute type; for merged duplicate edges one of the
 *            original attributes is kept
 * @return the de-duplicated graph with each vertex attribute replaced by the
 *         triangle count computed for that vertex
 */
def triangleCount[VD: ClassManifest, ED: ClassManifest](rawGraph: Graph[VD,ED]):
  Graph[Int, ED] = {
  // Remove redundant edges: duplicates of one (src, dst) pair keep the first attribute
  val graph = rawGraph.groupEdges( (a,b) => a ).cache

  // Construct set representations of the neighborhoods
  val nbrSets: VertexSetRDD[VertexSet] =
    graph.collectNeighborIds(EdgeDirection.Both).mapValuesWithKeys { (vid, nbrs) =>
      val set = new VertexSet
      var i = 0
      while (i < nbrs.size) {
        // prevent self cycle: a vertex is never its own neighbor
        if (nbrs(i) != vid) set.add(nbrs(i))
        i += 1
      }
      set
    }

  // Join the sets with the graph; a vertex without a neighbor set carries null
  val setGraph: Graph[VertexSet, ED] = graph.outerJoinVertices(nbrSets) {
    (vid, _, optSet) => optSet.getOrElse(null)
  }

  // Edge function: walk the smaller endpoint set and probe the larger one
  def edgeFunc(et: EdgeTriplet[VertexSet, ED]): Array[(Vid, Int)] = {
    assert(et.srcAttr != null, "missing neighbor set for source vertex " + et.srcId)
    assert(et.dstAttr != null, "missing neighbor set for destination vertex " + et.dstId)
    val (smallSet, largeSet) =
      if (et.srcAttr.size < et.dstAttr.size) { (et.srcAttr, et.dstAttr) }
      else { (et.dstAttr, et.srcAttr) }
    val iter = smallSet.iterator()
    var counter: Int = 0
    while (iter.hasNext) {
      val vid = iter.next
      // the edge's own endpoints are neighbors of each other, not a third corner
      if (vid != et.srcId && vid != et.dstId && largeSet.contains(vid)) { counter += 1 }
    }
    // both endpoints receive the number of common neighbors found on this edge
    Array((et.srcId, counter), (et.dstId, counter))
  }

  // Compute the intersection along edges and sum the results per vertex
  val counters: VertexSetRDD[Int] = setGraph.mapReduceTriplets(edgeFunc, _+_)

  // Merge counters with the graph and divide by two since each triangle is counted twice
  graph.outerJoinVertices(counters) {
    (vid, _, optCounter: Option[Int]) =>
      // vertices that received no message take part in no triangle
      val dblCount = optCounter.getOrElse(0)
      // double count should be even (divisible by two)
      assert((dblCount & 1) == 0, "odd double-count " + dblCount + " at vertex " + vid)
      dblCount/2
  }
} // end of triangleCount
|
||||
|
||||
|
||||
|
||||
|
||||
def main(args: Array[String]) = {
|
||||
val host = args(0)
|
||||
|
@ -277,7 +334,7 @@ object Analytics extends Logging {
|
|||
|
||||
val sc = new SparkContext(host, "PageRank(" + fname + ")")
|
||||
|
||||
val graph = GraphLoader.textFile(sc, fname, a => 1.0F,
|
||||
val graph = GraphLoader.edgeListFile(sc, fname,
|
||||
minEdgePartitions = numEPart, minVertexPartitions = numVPart).cache()
|
||||
|
||||
val startTime = System.currentTimeMillis
|
||||
|
@ -329,7 +386,7 @@ object Analytics extends Logging {
|
|||
|
||||
val sc = new SparkContext(host, "ConnectedComponents(" + fname + ")")
|
||||
//val graph = GraphLoader.textFile(sc, fname, a => 1.0F)
|
||||
val graph = GraphLoader.textFile(sc, fname, a => 1.0F,
|
||||
val graph = GraphLoader.edgeListFile(sc, fname,
|
||||
minEdgePartitions = numEPart, minVertexPartitions = numVPart).cache()
|
||||
val cc = Analytics.connectedComponents(graph)
|
||||
//val cc = if(isDynamic) Analytics.dynamicConnectedComponents(graph, numIter)
|
||||
|
@ -338,6 +395,31 @@ object Analytics extends Logging {
|
|||
|
||||
sc.stop()
|
||||
}
|
||||
|
||||
case "triangles" => {
  // Partition counts; each can be overridden by the matching command-line option
  var numVPart = 4
  var numEPart = 4

  options.foreach{
    case ("numEPart", v) => numEPart = v.toInt
    case ("numVPart", v) => numVPart = v.toInt
    case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt)
  }
  println("======================================")
  println("| Triangle Count |")
  println("--------------------------------------")
  val sc = new SparkContext(host, "TriangleCount(" + fname + ")")
  val graph = GraphLoader.edgeListFileUndirected(sc, fname,
    minEdgePartitions = numEPart, minVertexPartitions = numVPart).cache()
  val triangles = Analytics.triangleCount(graph)
  // Sum the per-vertex counts as Long, then divide by three because every
  // triangle is reported once at each of its three corners. fold (rather than
  // reduce) keeps an empty vertex set from throwing and yields 0 instead.
  val totalTriangles =
    triangles.vertices.map{ case (vid,data) => data.toLong }.fold(0L)(_+_) / 3
  println("Triangles: " + totalTriangles)

  sc.stop()
}
|
||||
|
||||
//
|
||||
// case "shortestpath" => {
|
||||
//
|
||||
|
|
|
@ -132,7 +132,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
|
|||
val chain1 = (0 until 9).map(x => (x, x+1) )
|
||||
val chain2 = (10 until 20).map(x => (x, x+1) )
|
||||
val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
|
||||
val twoChains = Graph(rawEdges)
|
||||
val twoChains = Graph(rawEdges, 1.0)
|
||||
val ccGraph = Analytics.connectedComponents(twoChains).cache()
|
||||
val vertices = ccGraph.vertices.collect
|
||||
for ( (id, cc) <- vertices ) {
|
||||
|
@ -153,7 +153,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
|
|||
val chain1 = (0 until 9).map(x => (x, x+1) )
|
||||
val chain2 = (10 until 20).map(x => (x, x+1) )
|
||||
val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
|
||||
val twoChains = Graph(rawEdges).reverse
|
||||
val twoChains = Graph(rawEdges, true).reverse
|
||||
val ccGraph = Analytics.connectedComponents(twoChains).cache()
|
||||
val vertices = ccGraph.vertices.collect
|
||||
for ( (id, cc) <- vertices ) {
|
||||
|
@ -167,8 +167,58 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
|
|||
else { assert(ccMap(id) === 10) }
|
||||
}
|
||||
}
|
||||
} // end of chain connected components
|
||||
} // end of reverse chain connected components
|
||||
|
||||
// A single directed 3-cycle: every vertex lies on exactly one triangle.
test("Count a single triangle") {
  withSpark(new SparkContext("local", "test")) { sc =>
    val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2)
    val graph = Graph(rawEdges, true).cache
    val triangleCount = Analytics.triangleCount(graph)
    val verts = triangleCount.vertices.collect
    // without this the per-vertex check would pass vacuously on an empty result
    assert(verts.size === 3)
    verts.foreach { case (vid, count) => assert(count === 1) }
  }
}
|
||||
|
||||
// Two directed 3-cycles sharing vertex 0: the shared vertex lies on both
// triangles, every other vertex on exactly one.
test("Count two triangles") {
  withSpark(new SparkContext("local", "test")) { sc =>
    val triangles = Array( 0L -> 1L, 1L -> 2L, 2L -> 0L ) ++
      Array( 0L -> -1L, -1L -> -2L, -2L -> 0L )
    val rawEdges = sc.parallelize(triangles, 2)
    val graph = Graph(rawEdges, true).cache
    val triangleCount = Analytics.triangleCount(graph)
    val verts = triangleCount.vertices.collect
    // vertices 0, 1, 2, -1, -2; guards against a vacuous pass on an empty result
    assert(verts.size === 5)
    verts.foreach { case (vid, count) =>
      if(vid == 0) { assert(count === 2) }
      else { assert(count === 1) }
    }
  }
}
|
||||
|
||||
// Same two triangles as above, but every edge is present in both directions.
// The expected values are twice the plain counts: triangleCount merges only
// duplicates of the same directed pair, so each undirected edge contributes
// twice to the per-vertex sums before they are halved.
test("Count two triangles with bi-directed edges") {
  withSpark(new SparkContext("local", "test")) { sc =>
    val triangles =
      Array( 0L -> 1L, 1L -> 2L, 2L -> 0L ) ++
      Array( 0L -> -1L, -1L -> -2L, -2L -> 0L )
    val revTriangles = triangles.map { case (a,b) => (b,a) }

    val rawEdges = sc.parallelize(triangles ++ revTriangles, 2)
    val graph = Graph(rawEdges, true).cache
    val triangleCount = Analytics.triangleCount(graph)
    val verts = triangleCount.vertices.collect
    // vertices 0, 1, 2, -1, -2; guards against a vacuous pass on an empty result
    assert(verts.size === 5)
    verts.foreach { case (vid, count) =>
      if(vid == 0) { assert(count === 4) }
      else { assert(count === 2) }
    }
  }
}
|
||||
|
||||
// The same directed 3-cycle listed twice: the repeated (src, dst) pairs must be
// collapsed so that each vertex still reports exactly one triangle.
test("Count a single triangle with duplicate edges") {
  withSpark(new SparkContext("local", "test")) { sc =>
    val cycle = Array( 0L->1L, 1L->2L, 2L->0L )
    val rawEdges = sc.parallelize(cycle ++ Array( 0L->1L, 1L->2L, 2L->0L ), 2)
    val graph = Graph(rawEdges, true).cache
    val triangleCount = Analytics.triangleCount(graph)
    val verts = triangleCount.vertices.collect
    // without this the per-vertex check would pass vacuously on an empty result
    assert(verts.size === 3)
    verts.foreach { case (vid, count) => assert(count === 1) }
  }
}
|
||||
} // end of AnalyticsSuite
|
||||
|
|
Loading…
Reference in a new issue